{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from evals import *\n",
    "from optimization import *\n",
    "from gauss_update import *\n",
    "from kinetic_model import *\n",
    "import cubic_spline_planner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f146eab2a20>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Control points of the reference course; the last one is the goal.\n",
    "ax = [0.0, 6.0, 12.5, 10.0, 7.5, 3.0, -1.0]\n",
    "ay = [0.0, -3.0, -5.0, 6.5, 3.0, 5.0, -2.0]\n",
    "goal = [ax[-1], ay[-1]]\n",
    "\n",
    "# Cubic spline through the control points, ds=0.2 (presumably the sampling step).\n",
    "# NOTE(review): the outputs look like x, y, yaw, curvature and arc length\n",
    "# (inferred from their names) - confirm in cubic_spline_planner.\n",
    "cx, cy, cyaw, ck, s = cubic_spline_planner.calc_spline_course(\n",
    "        ax, ay, ds=0.2)\n",
    "\n",
    "plt.scatter(cx, cy, label='spline samples')\n",
    "plt.scatter(ax, ay, c='r', label='control points')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class deviation_error_(torch.autograd.Function):\n",
    "    # Path-deviation loss with a hand-written gradient for the progress speeds vs.\n",
    "    #\n",
    "    # forward(states, vs, waypoints, start)\n",
    "    #   states    : predicted states, one row per horizon step\n",
    "    #   vs        : per-step progress along the waypoint table (assumed one entry per\n",
    "    #               row of states); step i is matched to waypoint\n",
    "    #               round(vs[0] + ... + vs[i]) + start\n",
    "    #   waypoints : reference table, indexed by those integer indices\n",
    "    #   start     : integer index offset\n",
    "    # Returns the detached result of error_deviation_parallel_ (a scalar in the\n",
    "    # recorded runs - NOTE(review): confirm in evals).\n",
    "    #\n",
    "    # Rounding to indices is not differentiable, so backward uses a straight-through\n",
    "    # estimate: the error gradient w.r.t. each reference row is projected onto the\n",
    "    # local step waypoints[i + 1] - waypoints[i] and accumulated back through the\n",
    "    # cumulative sum to vs. waypoints and start receive no gradient.\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, states, vs, waypoints, start):\n",
    "        # Detached leaf copy: the local autograd call below differentiates the error\n",
    "        # only, not whatever produced states.\n",
    "        states = states.detach().clone().requires_grad_()\n",
    "\n",
    "        # Cumulative progress rounded to waypoint indices. Clamp so that both inds and\n",
    "        # inds + 1 are valid rows; unclamped, a reference past the end of the table\n",
    "        # raised IndexError and a negative one silently wrapped around.\n",
    "        inds = torch.cumsum(vs.detach().view(-1), dim=0).round().long() + start\n",
    "        inds = inds.clamp(0, len(waypoints) - 2)\n",
    "\n",
    "        dw = waypoints[inds + 1] - waypoints[inds]\n",
    "        refs = waypoints[inds].detach().clone().requires_grad_()\n",
    "        with torch.enable_grad():\n",
    "            # NOTE(review): ql and qc are not defined in this cell; presumably they\n",
    "            # come from a star import or an earlier cell - confirm.\n",
    "            error = error_deviation_parallel_(states, refs, inds, ql, qc)\n",
    "            de_s, de_w = torch.autograd.grad(error, [states, refs])\n",
    "\n",
    "        ctx.save_for_backward(de_s.detach(), de_w.detach(), dw.detach())\n",
    "        return error.detach().clone()\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, de):\n",
    "        de_s, de_w, dw = ctx.saved_tensors\n",
    "        # Per-step derivative w.r.t. the position along the path: a row-wise dot\n",
    "        # product, without materialising an n x n matrix just to take its diagonal.\n",
    "        de_theta = (de_w * dw).sum(dim=1)\n",
    "        # Adjoint of the cumulative sum used in forward: a reversed cumulative sum.\n",
    "        de_v = de_theta.flip(0).cumsum(0).flip(0)\n",
    "        # Chain rule: scale by the incoming gradient de (ignoring it is only correct\n",
    "        # when this loss enters the objective with unit weight).\n",
    "        return de * de_s, de * de_v, None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([7.1618, 2.9268, 3.1397, 0.5057, 0.0000, 0.0000])"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column tensors of the spline samples, and the waypoint table: one row per\n",
    "# sample except the last, holding [x, y, dx, dy] with (dx, dy) the step to the\n",
    "# next sample.\n",
    "cx, cy = (torch.Tensor(c).view(-1, 1) for c in (cx, cy))\n",
    "gx, gy = cx[1:] - cx[:-1], cy[1:] - cy[:-1]\n",
    "waypoints = torch.cat([cx[:-1], cy[:-1], gx, gy], dim=1)\n",
    "\n",
    "# Initial state at waypoint `start`: heading along the local step, speed three\n",
    "# times the step length; delta0 and a0 (presumably steering and acceleration)\n",
    "# start at zero. Layout: [x0, y0, yaw0, v0, delta0, a0].\n",
    "start = 150\n",
    "x0, y0, gx0, gy0 = waypoints[start]\n",
    "yaw0 = torch.atan2(gy0, gx0)\n",
    "v0 = (gx0 ** 2 + gy0 ** 2).sqrt() * 3\n",
    "delta0, a0 = torch.zeros(1), torch.zeros(1)\n",
    "state0 = torch.cat([t.view(1) for t in (x0, y0, yaw0, v0)] + [delta0, a0])\n",
    "control0 = torch.Tensor([0.0001, 0.0001, 0.001])\n",
    "state0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override the initial speed: index 3 of state0 is v0 in the\n",
    "# [x0, y0, yaw0, v0, delta0, a0] layout assembled in the setup cell.\n",
    "state0[3] = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001,\n",
      "        3.0001])\n",
      "0.7396644 -0.30000994\n",
      "0.13342646 -0.30000994\n",
      "0.08984161 -0.30000994\n",
      "0.07481848 -0.30000994\n",
      "0.0706311 -0.30000994\n",
      "0.06259345 -0.30000994\n",
      "0.051519826 -0.30000994\n",
      "0.046565443 -0.30000994\n",
      "0.04194828 -0.30000994\n",
      "0.038141076 -0.30000994\n",
      "0.034598093 -0.30000994\n",
      "0.031222954 -0.30000994\n",
      "0.02798053 -0.30000994\n",
      "0.024869721 -0.30000994\n",
      "0.021899631 -0.30000994\n",
      "0.019091371 -0.30000994\n",
      "0.016481498 -0.30000994\n",
      "0.014122797 -0.30000994\n",
      "0.012078099 -0.30000994\n",
      "0.010403979 -0.30000994\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f146062b1d0>"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZCUlEQVR4nO3de4xcZ3nH8e/TxMabe8su0MZ21qIpFaJAyApBU9GIJCiF1KGFP3ALgklUq+KSQBNvE9IShYqKrikhCpDKImwoxE6rEIRbUcChjSqkQrw24RYnJQIncZrUs70AcUzjtE//ODPx8Xhmd2bnzLyX8/tIq/XOzp55vTvznGd+73vOMXdHRETy83OhByAiIqOhAi8ikikVeBGRTKnAi4hkSgVeRCRTJ4Z40MnJSZ+eng7x0CIiydqzZ8+iu0/1e/8gBX56epqFhYUQDy0ikiwze3iQ+yuiERHJlAq8iEimVOBFRDKlAi8ikikVeBGRTKnAi4hkSgVeRCRTKvAiMj6Li7B1Kzz4YPF5cTH0iLIW5EAnEamZxUWYn4dDh+CGG+Cee+BLXyq+12gU39u4EXbuLL6enAw63FyowIvI6HQW9uuvh7m5opiff/7R4j47e7ToHzoEJ5+sQl8BFXgRqV6vwl4u2lu2FJ8bjeJzu+gfOlQU/Pb35udV7FdIBV5EqtfuyrsV9k6Tk0eL/ZYtxc6h3cG3t6OufkVU4EWkGu2uvdE42pWvpCCXC357O+Wuvv09WZYKvIgMpzOOgaIIV1GI28W+3dVv3FisvlEn3xctkxSR4bRjFDgax1StXeh37iwea36++sfIkDp4ERnOMHHMSh9LnXxf1MGLyMq0D1qCorseR6FVJz+QSgq8mZ1hZnea2QNmts/MXl3FdkUkYu1oJkSRbTSOxkHtHY2Oij1OVRHNTcCX3f3NZrYaOKmi7YpIbNqTqhs3Fl+PInNfTnmlzdatWmHTw9AF3sxOB14DvAPA3Z8Gnh52uyISqfKkagwFVbl8T1V08BuAJjBvZi8D9gBXuvuh8p3MbDOwGWD9+vUVPKyIBFGeVI1Bu5tXJ3+cKjL4E4FXALe4+znAIeCazju5+zZ3n3H3mampqQoeVkTGKsSk6iDKubwA1RT4A8ABd/9m6+s7KQq+iOQk5KRqPzpzeU26Dh/RuPsTZvaomb3I3R8ELgDuH35oIhKV2KKZXmKbIwioqlU07wFub62g+SEQ+TNARPpWPsdMCgUzlR3RGFSyDt7d72vl6y919ze6+39VsV0RiUDs0UwnRTXP0qkKRGRpqXbEimpU4EWki3IsU+6IU5LqjqlCOheNiBwvtVimG0U16uBFpIucut8aRzUq8DXSbML+/TA9DTrWTLpKbcVMP3LaWQ1IEU1N7NgBZ50FF11UfN6xI/SIJEo5RDOdahzVqINPXD9debMJl18Ohw8XH1B8feGF6uSlQ87dbg2jGnXwCeu3K9+/H1avPva2VauK20WA+M8zU4UanqtGBT5yzSbs3l187ry93ZWf+ONF3nV4K7OXLbL4QOniB60X7YZTFzn1fxa5mq08l+Lt6ZEjRdcvAuQZzXSqYVSjiCZiO3YURXz1anj6afjsjYu86SfFBNiB++Aqn+cWGjSYZyuzrHE4/Eng5tmjG5mdZRLY+Ttwzo5Z1jwH/vLntvDZGxeZum2+67mzNRlbQzlHM51qFNWowEeq3aGfdHiRxuF55mmw593zvOmZ4ol59pPwZz+b5WfAfOvUP3dYgyvfCazj2Bdqo8E5DXjybPjd32hwxcspinv7Sd5oPLtyYseuyWN2KrfeCps2jfW/LiGkejDTStRpZ+buY/8499xzXZbQbPoj75nzDac2/Wrm3MGvZs6nTylu92bTvdn0vZvmfO2app92mvvEhPv27YM9hs/NHf0M/tMPzPnEhPtzKR73uTR9YsL94MGR/U8ltPLzQKIHLPgAtVYd
fIzm51l38yxvORG2tbrzeRo89b+TrPnTLdBKVM7ZvoW9K41Tyh1bq5P5wcsbrL4JGoeLyAdg26otHLhvkan7usc5krgaxRVddZ6SITMq8LEoP9FaBffc0xo89b5Jtq3awpEjRVzSWcSnpirIyVvFfm2ziGXmSzuVI0fg7K/PwwePj3OYnFRen7o6xRXd5L6DG6Tdr+pDEU1J+y3y9de7Q/HvkoMH3e+9d3wxyfbtRdxzTOzTJc7xubln73v66SuIiCQcxTJHJfa7YMCIRgU+tHbBvP76aJ5oS+5UWi+I5r6m8vpUlXbS0pJIoR+0wCuiCaFLHBNTBrhk7NOKc360u1hp05nX79+vqCZ6dY9lusk0qlGBD6HzyZTgE2p6untev+HURdia76RV0nI8kVhVMt3pqcCPU/sFtnFj8XXCT6apqWLS9/LLj50Envy7pTshTcoGlGmXWolMjwNQgR+nzF5gmzYVJyw7pmAvdnRCpa5RB1EFlmmXWpkcl0wOEthX9VGrSdby5E0iEzmV6jiICo5+aFJWopLA5DOaZI1MBnn7UDoOojrp8CINilMvHFk1qUnZccixMx2FDN/hqMCPSkZ5+1A6DqL6A46uuvnEkS06o+U4ZBYNjkyGOXxlBd7MTgAWgMfc/ZKqtpssvaiO0Z6Unb2swRovTox2660wZVp1M3IZdqYjldE7nio7+CuBfcBpFW4zXXpRHaeYlJ1k//4tXDHdima2akc4Ep1FSr/b/mXUnFVS4M1sLfAG4EPAH1WxzWRprfGSjjuIqtuOMKMOKpiMitTYZdScVdXBfwyYBU7tdQcz2wxsBli/fn1FDxshvbAG06271O9weBkVqbHL6B3P0JfsM7NLgIPuvmep+7n7NnefcfeZqZyXTdTwuo+VK/0Om0341q5Fnry+HpdYG1odrq06LouLyV/ar4prsp4HbDSz/cAdwGvN7HMVbDctemFVp9VB7dg1yVlnwV0b5znlg7N864qMrxdalTpcW3VcMvhdDh3RuPu1wLUAZnY+cLW7v3XY7SZHsUKlyhcVv4UGPwPu+EKDvc3Wyhtl9N0pmqlOBr9LrYMfRq+zQsrQ9u8vTmlw+DD8B5N8hC2ctrq4feoe7UyPoRUzo5HB77LSAu/u9wD3VLnNqNX9KNURap+tsuzIkeL243amdV91o3eP0oM6+GGoax+Zo2erhFWr6LhkYUdnVfcCp+fhaCXcQFQxyVo/mlAdi02b4OGH4e67i889zzzZuXIpg9UPyyr/H9tRgp6Ho5HwZKs6+JWoe8c4Rn1dVLwzK63D36cO/8dYJPwOSQV+JRL+g9dCrhm9JvXDSHiyVRHNIBTNJKHpk+w+fwtNb/19En6LfYzy/0OxzPglGP2pgx+E3hZHb8cOjr9qVModvbr2eCT4+leBH4ReYFErHxx1+HBx2+WXw4UPTzKVakavpbjxSPD1rwI/iISzuDooHxzVtmoVx181KvaOXl17nBJ8/SuD70eC2VsdLXlwVFlnft2Z0Yf+eytrl4qowPcjl0m6zLUPjpqYgNNOKz4fPThqCa119Iu/3WD3bnjy5qULfrMJu3cXnytTfgydkTReoXf+gxrkCt1VfZx77rnVX258lJrN4krrzWbokUgfDh50v/fe4nO/tm93n5hwP/1097Vrmr53U+nvPTfnDu5zc759e/H9P1kz52vXNH379iEGWn5elR5DIhb47wQs+AC1VgV+OSru2Tt4sCjucPRjYqK0g2g9B5r7mj4x4X41xYv8auZ8YsK9ua/jOVJ+znQ+f3oVdT3P0hD47zRogdck63JSWnEhK7Ls5GwrB//R7uJ+84eL6GSeBqtWweFPzsPNpedI+TkDxz5/yt8rT6AmOIFXS4n9nVTgl6NVDNnrd3K2fb8ft05fDDBxBCbe2YB1HP9cKT9nun0vsWIhJbGtvOrBiq5/vGZmZnxhYWHsjyvSS/sAqfKZK7ud3Kzf+0nmtm4t3onNzY11J21me9x9pu/7
q8D3kMgeWqrTbBaxzPT00itv+r2fZCxQfRi0wCui6UXZe+30debKAe4nGUskXlOB70XZu4gkTgc6ddIZI0WkHwkc9KQC30lHrYpIPxKoFYpoOimaEZF+JFArtIpGRCQRg66iUUTTlkCeJiIyCBX4tgTyNBGJUMTN4dAZvJmtA/4aeD7gwDZ3v2nY7Y5dAnmaiEQo4mNmqphkfQa4yt33mtmpwB4z2+Xu91ew7fFJ5MAFEYlMxM3h0BGNuz/u7ntb//4psA84c9jtjk3Eb69EJAERX3Wr0gzezKaBc4BvdvneZjNbMLOFZqWXwhmSsncRGVakjWJl6+DN7BTg88B73f0nnd93923ANiiWSVb1uEOL+O2ViCQi0hy+kgJvZqsoivvt7n5XFdscG2XvIjKsSBvFoSMaMzPgVmCfu390+CGNUaRvq0QkMZHm8FVk8OcBbwNea2b3tT5eX8F2R0/5u4hkbOiIxt2/DlgFYxm/SN9WiUiiIrtQUL2PZI30bZWIJCqyVKCeZ5OMbC8rIpmILBWoZwcf2V5WRDIRWSpQzw4+sr2siMgo1LODj2wvKyIZiWj5db0KfES/eBHJVEQRcL0imkgPJxaRjEQUAderwEf0ixeRTEV0+pN6FfiIfvEiIqNWnwxe+buIjEsk9aY+BT6iiQ8RyVwk9aY+EY3ydxEZl0jqjbmP/9obMzMzvrCwMPbHFRFJmZntcfeZfu9fn4hGRKRm8i/wkUx2iEjNRFB78i/wkUx2iEjNRFB78p9kjWSyQ0RqJoLao0lWEZFEaJK1LIIMTEQklLwLfAQZmIjUWOAmM+8MPoIMTERqLPAZbPMu8Dq5mIiEFLjJzLvAi4iEFLjJzDOD1+SqiMQiYD2qpMCb2cVm9qCZPWRm11SxzaFoclVEYhGwHg0d0ZjZCcAngIuAA8BuM9vp7vcPu+0V0+SqiMQiYD2qooN/JfCQu//Q3Z8G7gAurWC7K9fOvSYnl76fohwRGbV+69EIVFHgzwQeLX19oHXbMcxss5ktmNlCs9ms4GF7GKRod751UsEXkYyMbZLV3be5+4y7z0xNTY3ugQbJuxoNmJs7+tap/LMq9iJShYC1pIplko8B60pfr23dFsYgeVfnEqbyz3YeoLC4WNzWaAR5qyUiiQp4sFMVBX43cLaZbaAo7G8Bfq+C7a7MMOtOyz/buaMIfESaiCQq4CTr0AXe3Z8xs3cDXwFOAD7t7t8femShLdXdgzp6EelPwIOdKsng3f1L7v4r7v5Cd/9QFdsc2Khzrs6ZcOX1IhK5fI5kHffBBOUJWq3GEZGlBKoJ+ZyLZtw5V795fXsHoChHpL4CzeHlU+BDntRnkNU4IlI/gSZa87hkX8wTnp1ji3msIhK1el6yL+aTi2lyVkQCySOiSenkYjqYSqR+Ar228yjwKV25SZOzIvWjSdYa0uSsSD1oknWFco01NDkrIh3qN8ka8wTrMJaanAVN0IrIstKPaFKaYB2G8nqRtAV4F55+gU9pgnUYyutF0hbgdZpuga97Jr3Uapy6/25EYhQgbUg3g881e18J5fUi8QtwbdZ0O/i6ZO8roYuViMRHGfwA6pK9r8RSeb3iG5EwlMHLSJQL/tatOj2CSAjK4PukTHnlyhcqAeX1IuOiDL5PypRXbrlrzep3K5KNNAu8Jliro4uLi4xHgNdSmhFNgLc6taEllyKjEWBpd5odvIyPIhyRagRIHtI6m6TigvDKfwPQ30NkjMZ6Nkkz22pmD5jZd8zsC2Z2xjDbW5aOXg2vHOEovhHpX4DXx7ARzS7gWnd/xsz+ArgW+OPhh9WDJlfjovhGpH+pHejk7l8tffkN4M3DDWcZOno1LlqBI9K/AA1qlZOslwF/0+ubZrYZ2Aywfv36Ch9WotFZ8NXRixwVoEFdtsCb2d3AC7p86zp3/2LrPtcBzwC399qOu28DtkExybqi0Upa1NGLBLVsgXf3C5f6vpm9A7gEuMBDLMmReC3X0avgS52k
djZJM7sYmAV+092fqmZIki1NykqdpTbJCnwceA6wy8wAvuHufzj0qCRPmpSVOkttktXdf7mqgUgNaVJW6iTAJGua56KRPHWeyhh08JTkI8BzWQVe4tHtJHI6WlZyoZONiXTQxKzkQicbE1lG50SsJmalRsZ6sjGRsVvufPUisUrwZGMiYWmppaQiwXXwImFpqaWkIrV18CLR6fYiUlcvMdA6eJEh9bPUUmTcAi3vVQcv+VNOL6EFig5V4CV/yukltEBXo1OBl/pRRy/jFuhqdMrgpX60ll7GTRm8SCBaeSOjpgxeJJBub5+V00uVlMGLREQ5vVRJGbxIRJTTS1UCnuJaHbxIP5TTy0oFjPtU4EX6oZxeVipQ/g4q8CIrp65e+hEofwdl8CIrp/PeyHICX2JSHbxIlbT6RsoCx3gq8CJV0nlvpCxg/g4q8CKjpZy+viL4O1eSwZvZVWbmZqZnq0iZcvr6iuDvPHQHb2brgNcBjww/HJEaUFdfD4HjGaimg78RmAW8gm2J5E9dff4i2WEP1cGb2aXAY+7+bTNb7r6bgc0A69evH+ZhRfKj1Td5iWRyfdkCb2Z3Ay/o8q3rgPdTxDPLcvdtwDaAmZkZdfsiZVp9k5cI4hnoo8C7+4XdbjezXwM2AO3ufS2w18xe6e5PVDpKkbpRTp+m8t8ogh3ziiMad/8u8Lz212a2H5hx9zCHbInkROe+SVNkfyOtgxdJhbr6+EUSzbRVVuDdfbqqbYlIF+rq4xbhzlYdvEjKtPomHhHubFXgRVLWz+obFf3xiCyeARV4kbx0KzIRdpZZiWzlTJkKvEhOuuX0mpwdrYh3oCrwIrnT5OxotHeSGzcWX0cUzbSpwIvUkSZnh5fATlIFXqSOdGqE4UU4qdpJ12QVkaJIzc0dn9MHvJ5otNq/Fzj+rKCRUYEXkf5OYayCX0jo1M6KaESku84Iou5r7BOYVO2kAi8i3XXm9HVdY98u7IcOwQ03FLcl8n9VRCMi/ekW43Rm9znGOOWdWOc8ReTUwYvIyi23GifVCKc87vI7l5T+D6iDF5EqdXb0KU3UlsdWHne3dy6JUAcvItVZLrePscPvlrEnsMa9H+rgRWR0OrvfQTv8qjr+8nY6t9ktY0+4ay9TBy8i4zNoh79cx7/U1+3tNRrHbgeO3WbCGftyVOBFJJzlCv6gO4Behbxb5NL+d7eTsWXC3H3sDzozM+MLCwtjf1wRSdxKO/hMOnMz2+PuM33fXwVeRCQNgxZ4TbKKiGRKBV5EJFMq8CIimVKBFxHJlAq8iEimVOBFRDKlAi8ikqkg6+DNrAk8PPYH7m4SiPDUdl1prNVLZZygsY5CKuOEYqwnu/tUvz8QpMDHxMwWBjlwICSNtXqpjBM01lFIZZywsrEqohERyZQKvIhIplTgYVvoAQxAY61eKuMEjXUUUhknrGCstc/gRURypQ5eRCRTKvAiIpmqdYE3s4vN7EEze8jMrgk9nl7MbJ2Z/ZOZ3W9m3zezK0OPaSlmdoKZfcvM/j70WJZiZmeY2Z1m9oCZ7TOzV4ceUy9m9r7W3/57ZrbDzNaEHhOAmX3azA6a2fdKt/2Cme0ysx+0Pv98yDG29Rjr1tbf/ztm9gUzOyPkGNu6jbX0vavMzM1s2auY1LbAm9kJwCeA3wJeDGwysxeHHVVPzwBXufuLgVcB74p4rABXAvtCD6IPNwFfdvdfBV5GpGM2szOBK4AZd38JcALwlrCjetZtwMUdt10DfM3dzwa+1vo6Brdx/Fh3AS9x95cC/wpcO+5B9XAbx48VM1sHvA54pJ+N1LbAA68EHnL3H7r708AdwKWBx9SVuz/u7ntb//4pRSE6M+youjOztcAbgE+FHstSzOx04DXArQDu/rS7/3fYUS3pRGDCzE4ETgL+LfB4AHD3fwb+s+PmS4HPtP79GeCNYx1UD93G6u5fdfdnWl9+A1g79oF10eP3CnAjMAv0tTqmzgX+TODR0tcHiLRolpnZ
NHAO8M2wI+npYxRPwP8LPZBlbACawHwrTvqUmZ0celDduPtjwEcourbHgR+7+1fDjmpJz3f3x1v/fgJ4fsjBDOAy4B9CD6IXM7sUeMzdv93vz9S5wCfHzE4BPg+8191/Eno8nczsEuCgu+8JPZY+nAi8ArjF3c8BDhFPlHCMVoZ9KcVO6ZeAk83srWFH1R8v1mFHvxbbzK6jiEJvDz2WbszsJOD9wAcG+bk6F/jHgHWlr9e2bouSma2iKO63u/tdocfTw3nARjPbTxF5vdbMPhd2SD0dAA64e/ud0J0UBT9GFwI/cvemux8B7gJ+PfCYlvLvZvaLAK3PBwOPZ0lm9g7gEuD3Pd4Dg15IsYP/duv1tRbYa2YvWOqH6lzgdwNnm9kGM1tNMWm1M/CYujIzo8iK97n7R0OPpxd3v9bd17r7NMXv8x/dPcpO092fAB41sxe1broAuD/gkJbyCPAqMzup9Vy4gEgnhFt2Am9v/fvtwBcDjmVJZnYxRaS40d2fCj2eXtz9u+7+PHefbr2+DgCvaD2Pe6ptgW9NrLwb+ArFi+Vv3f37YUfV03nA2yg64vtaH68PPagMvAe43cy+A7wc+PPA4+mq9S7jTmAv8F2K120Uh9ib2Q7gX4AXmdkBM7sc+DBwkZn9gOLdx4dDjrGtx1g/DpwK7Gq9rv4q6CBbeox18O3E+45ERESGUdsOXkQkdyrwIiKZUoEXEcmUCryISKZU4EVEMqUCLyKSKRV4EZFM/T8jFx9sgQH9UwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# First run: jointly optimise the controls and the per-step progress speeds vs\n",
    "# over one horizon, then plot the optimised trajectory against the spline course.\n",
    "\n",
    "len_horizon = 10\n",
    "EMAX=2000\n",
    "kappa=0.01\n",
    "TMAX = 1\n",
    "_controls = torch.nn.Parameter(torch.zeros(len_horizon,2))\n",
    "\n",
    "dt = torch.ones(len_horizon,1)\n",
    "# vs has to be a leaf that requires grad. The former\n",
    "# torch.nn.Parameter(...).data.clone()*3.0001 produced a plain tensor with\n",
    "# requires_grad=False, so Adam never moved it (v_loss stayed at -0.30000994).\n",
    "vs = torch.nn.Parameter(torch.full((len_horizon,), 3.0001))\n",
    "dev = deviation_error_.apply\n",
    "\n",
    "print(vs)\n",
    "for T in range(TMAX):\n",
    "    if T>0:\n",
    "        # NOTE(review): s2_, controls2_ and start2_ are not defined anywhere in this\n",
    "        # notebook (presumably a removed cell); this branch only runs when TMAX > 1.\n",
    "        state0 = s2_[0]\n",
    "        _controls = torch.nn.Parameter(controls2_.data.clone())\n",
    "        start = start2_\n",
    "    # Cut any graph attached to state0 so every backward pass ends at the initial\n",
    "    # state; this is what makes retain_graph=True unnecessary.\n",
    "    init_state = state0.detach()\n",
    "    # Build the optimiser after _controls may have been replaced, so it updates the\n",
    "    # Parameter that is actually used below.\n",
    "    opt = torch.optim.Adam([vs, _controls], lr=0.01)\n",
    "    for epoch in range(EMAX):\n",
    "        controls = torch.cat([_controls,dt],dim=1)\n",
    "        s=init_state\n",
    "        s_=[]\n",
    "        for t in range(len_horizon):\n",
    "            s = model(s,controls[t])\n",
    "            s_.append(s.view(1,-1))\n",
    "        s_ = torch.cat(s_,dim=0)\n",
    "        if epoch==0:\n",
    "            s0_ = s_.data.clone()\n",
    "        dev_loss = dev(s_,vs,waypoints,start)\n",
    "        v_loss = -kappa*vs.sum()\n",
    "        loss = dev_loss + v_loss\n",
    "        if epoch % 100 ==0:\n",
    "            print(dev_loss.data.numpy(),v_loss.data.numpy())\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "x_ = s_[:,0].data.numpy()\n",
    "y_ = s_[:,1].data.numpy()\n",
    "\n",
    "plt.scatter(x_,y_,c='b',s=20)\n",
    "plt.scatter(cx,cy,c='r',s=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 4.9153,  3.8062,  2.4008,  0.7471,  0.6306, -0.0991],\n",
      "       grad_fn=<SelectBackward>)\n",
      "Parameter containing:\n",
      "tensor([[ 0.0423, -0.1589],\n",
      "        [ 0.0579, -0.8063],\n",
      "        [-1.0263, -1.1404],\n",
      "        [ 1.0257,  1.1381],\n",
      "        [-0.0241, -0.7162],\n",
      "        [-0.0181,  0.1500],\n",
      "        [-0.0085,  0.1448],\n",
      "        [ 0.0000,  0.0000],\n",
      "        [ 0.0000,  0.0000],\n",
      "        [ 0.0000,  0.0000]], requires_grad=True)\n",
      "tensor([3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001,\n",
      "        3.0001])\n",
      "162\n",
      "0.0006184621 -0.30000994\n",
      "0.00079016766 -0.30000994\n",
      "0.00062702526 -0.30000994\n",
      "0.0005359805 -0.30000994\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-108-74a3231f6b7d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     38\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdev_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mv_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m         \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m         \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     41\u001b[0m         \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[0mx_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ms_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m    105\u001b[0m                 \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m         \"\"\"\n\u001b[0;32m--> 107\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    108\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    109\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m     91\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m     92\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m         allow_unreachable=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Warm-started run: re-optimise from the state, controls and start index of an\n",
    "# earlier run.\n",
    "# NOTE(review): s2_, controls2_ and start2_ are not defined anywhere in this notebook;\n",
    "# presumably they came from a removed cell - restore it before re-running.\n",
    "\n",
    "len_horizon = 10\n",
    "EMAX=2000\n",
    "kappa=0.01\n",
    "TMAX = 1\n",
    "\n",
    "dt = torch.ones(len_horizon,1)\n",
    "# A real leaf Parameter, so that Adam can update vs (the former\n",
    "# Parameter(...).data.clone()*3.0001 had requires_grad=False).\n",
    "vs = torch.nn.Parameter(torch.full((len_horizon,), 3.0001))\n",
    "\n",
    "# s2_[0] still carries the earlier run's graph (it printed with a grad_fn); detach\n",
    "# it so backward stops here instead of walking into that graph, which is what\n",
    "# used to require retain_graph=True.\n",
    "state0 = s2_[0].detach()\n",
    "_controls = torch.nn.Parameter(controls2_.data.clone())\n",
    "start = start2_\n",
    "print(state0)\n",
    "print(_controls)\n",
    "\n",
    "print(vs)\n",
    "opt = torch.optim.Adam([vs, _controls],lr=0.01)\n",
    "print(start)\n",
    "dev = deviation_error_.apply\n",
    "for T in range(TMAX):\n",
    "    for epoch in range(EMAX):\n",
    "        controls = torch.cat([_controls,dt],dim=1)\n",
    "        s=state0\n",
    "        s_=[]\n",
    "        for t in range(len_horizon):\n",
    "            s = model(s,controls[t])\n",
    "            s_.append(s.view(1,-1))\n",
    "        s_ = torch.cat(s_,dim=0)\n",
    "        if epoch==0:\n",
    "            s0_ = s_.data.clone()\n",
    "        dev_loss = dev(s_,vs,waypoints,start)\n",
    "        v_loss = -kappa*vs.sum()\n",
    "        loss = dev_loss + v_loss\n",
    "        if epoch % 100 ==0:\n",
    "            print(dev_loss.data.numpy(),v_loss.data.numpy())\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "x_ = s_[:,0].data.numpy()\n",
    "y_ = s_[:,1].data.numpy()\n",
    "\n",
    "plt.scatter(x_,y_,c='b',s=20)\n",
    "plt.scatter(cx,cy,c='r',s=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001, 3.0001,\n",
       "         3.0001]),\n",
       " tensor([[ 5.5189,  3.2250,  2.3331,  0.5060, -0.1303, -0.3030],\n",
       "         [ 5.1695,  3.5910,  2.3066,  0.2031,  0.8414,  0.4011],\n",
       "         [ 5.0332,  3.7415,  2.3975,  0.6042,  0.7358,  0.1132],\n",
       "         [ 4.5886,  4.1508,  2.6163,  0.7174, -0.1112, -0.0818],\n",
       "         [ 3.9679,  4.5106,  2.5843,  0.6356, -2.3420, -0.1057],\n",
       "         [ 3.4285,  4.8468,  2.8458,  0.5299,  0.5118, -0.0430],\n",
       "         [ 2.9216,  5.0012,  2.9649,  0.4869,  1.1201, -0.0479],\n",
       "         [ 2.4423,  5.0868,  3.3673,  0.4390,  0.9907,  0.0083],\n",
       "         [ 2.0145,  4.9886,  3.6352,  0.4473,  0.9907,  0.0083],\n",
       "         [ 1.6206,  4.7767,  3.9082,  0.4556,  0.9907,  0.0083]],\n",
       "        grad_fn=<CatBackward>))"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Progress speeds and the trajectory from the last optimisation step.\n",
    "(vs, s_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 4.9010,  3.9226,  2.4992,  0.7205,  0.6883,  0.3641],\n",
       "        [ 4.3241,  4.3543,  2.7363,  1.0846,  0.6883,  0.3641],\n",
       "        [ 3.3274,  4.7819,  3.0931,  1.4486,  0.6883,  0.3641],\n",
       "        [ 1.8805,  4.8522,  3.5697,  1.8127,  0.6883,  0.3641],\n",
       "        [ 0.2314,  4.0997,  4.1661,  2.1768,  0.6883,  0.3641],\n",
       "        [-0.8995,  2.2398,  4.8822,  2.5408,  0.6883,  0.3641],\n",
       "        [-0.4700, -0.2645,  5.7182,  2.9049,  0.6883,  0.3641],\n",
       "        [ 1.9835, -1.8198,  6.6739,  3.2690,  0.6883,  0.3641],\n",
       "        [ 5.0060, -0.5748,  7.7494,  3.6331,  0.6883,  0.3641],\n",
       "        [ 5.3852,  3.0385,  8.9447,  3.9971,  0.6883,  0.3641]])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Trajectory recorded at epoch 0 of the training loop (s0_ = s_.data.clone()).\n",
    "s0_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.0423, -0.1589],\n",
       "         [ 0.0579, -0.8063],\n",
       "         [-1.0263, -1.1404],\n",
       "         [ 1.0257,  1.1381],\n",
       "         [-0.0241, -0.7162],\n",
       "         [-0.0181,  0.1500],\n",
       "         [-0.0085,  0.1448],\n",
       "         [ 0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000]], grad_fn=<CopySlices>), Parameter containing:\n",
       " tensor([[-0.0084,  0.5544],\n",
       "         [ 0.0423, -0.1589],\n",
       "         [ 0.0579, -0.8063],\n",
       "         [-1.0263, -1.1404],\n",
       "         [ 1.0257,  1.1381],\n",
       "         [-0.0241, -0.7162],\n",
       "         [-0.0181,  0.1500],\n",
       "         [-0.0085,  0.1448],\n",
       "         [ 0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000]], requires_grad=True))"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAY9ElEQVR4nO3df5BdZX3H8fdXAiUBAi27SksIm7HU1rFaYOtomaGOgEMVAxb/gFZHLqlpx6qgsJEfUxl0ZOyuU3QqYycjbqxCaAdxRMdagi3TsVOFTfxNpDIaIBSavbZFjHGStN/+ce7Nnr25P/eee5/nPOfzmsns3t275z65u/d7v+fzPOccc3dERCQ9Lwg9ABERGQ0VeBGRRKnAi4gkSgVeRCRRKvAiIolaFeJBJyYmfGpqKsRDi4iU1s6dO+vuPtnv/YMU+KmpKRYWFkI8tIhIaZnZE4PcXxGNiEiiVOBFRBKlAi8ikigVeBGRRKnAi4gkSgVeRCRRKvAiIolSgReR8anXYW4u+9j8/LHHlr4mhQpyoJOIVEy9DvPzsH8/3Hrr0te3bIGHHoIvfzn73gknQK0GExPBhpoSFXgRGZ3Wwn7LLTA7mxXxpo0b4TWvye6zZYsKfYFU4EWkeN0Ke75oz8wsfazXs8LeLPSQ3X9+XsV+hVTgRaR48/NZke5U2NuZmFhe6JvFXV39iqnAi0gxml17rbYUwaykIDcLffPnYXlX3/ye9KQCLyLDaTeBOjNTTCFu7eo3bsxW3KiT74uWSYrIcJoxChw9gVqUZqG///7ssebni3+MBKmDF5HhDBPHrPSx1Mn3RR28iKxM80AlyLrrcRRadfIDKaTAm9kpZnavmf3AzHab2auL2K6IRKwZzYQosrXaUhyUPzpWlikqovkY8BV3f7OZHQesKWi7IhKb5qTqxo3Z7VFk7r3kV9rMzWmFTQdDF3gzOxk4H7gKwN0PAgeH3a6IRCo/qRpDQVUu31ERHfwGYBGYN7NXADuBa9x9f/5OZrYZ2Aywfv36Ah5WRILIT6rGoNnNq5M/ShEZ/CrgHOAT7n42sB+4ofVO7r7V3afdfXpycrKAhxWRsQoxqTqIfC4vQDEFfi+w192/0bh9L1nBF5GUhJxU7UdrLq9J1+EjGnd/1syeMrOXuPtjwAXAo8MPTUSiEls000lscwQBFbWK5l3AXY0VND8CIv8LEJG+5c8xU4aCWZY3ojEoZB28u3+rka+/3N0vc/f/LmK7IhKB2KOZVopqjtCpCkSku7J2xIpqVOBFpI18LJPviMukrG9MBdK5aETkaGWLZdpRVKMOXkTaSKn7rXBUowIvIkvKtmKmHym9WQ1IEU1FLC7CI49kH0U6SiGaaVXhqEYdfAIWF2HPHpiagnZngdi+HTZtguOOg4MH4c474corxz1KKYWUu90KRjXq4Etu+3Y480y46KLs4/bty7+/uJgV9wMH4Lnnso+bNqmTlxaxn2emCBU8V40KfOQ6Riv1Oj+7ZY4tV9dZc6DO25+bY82BOluuzr5OvQ71Or/44BynrapzKnWuZ45TqXPssVnHL3JEitFMq2ZUMzFRmYuEKKKJWGu08pnb61z+08YE2Pw8J35gC1cdD78A5sh2PY93OPEDW+DEbBtn/PUWrlgFh3L3uePQDBtOqsPc/FI3k1vz3CvykQSlHM20U5G4RgU+Us1oZc2BOrUD88xTY+c757n8cOOPslbjZz+DbbM1DjR+Zp4aqw3e9344MfdCPXdtjWuvzYr/PVbjzjth4ou5P3A48vn2dTPK66uorAczrVRV3tDcfez/zj33XJcuFhf9yXfN+oaTFv16Zt3Br2fWp07Mvu6Li0fuevfd7qtXu69dm328++72m9y3z/3hh7OPzcfw2ca2Gp8v7l701avdTyV73FPJbh/5GUlP/u9Aogcs+AC1VgU+oKOKbtNsVtRvXDXbV7HtuJ0BPfyw
+8kn+7I3lbVr3Xc9oCKQrMbfms/Ohh5JGCV7gxu0wCuiCaRrvt7YbTx3bY2fv2eCrcfOcOhQFpe0y8QnJ4vJyqemsrHMN872PE+NQ4fgrK/NwweWoqFl5yih9zJNiVhVoopOUs/iB3k3KOpf1Tv4ffuyOAWW4pAPHnNL206qqO68X20jn3yX09LxNe9/8sndIyKJSMm61pEq2XOBIpr4NaMQWIpDbjvulqPy9VC6vqnkXhD79rmvO34pQgJXZl8GVY9l2ilJoR+0wCuiCWDDSXXesX+erdSOxCH3vKDGn/zFBERwfEnXyCe32mLPI3AV83ywsfzyI8wcWWOvqCZiVY9l2kk0qlGBD2Dii/PcdngLrII71sxwx6GZjvl6zKamYBs1fsFSbr/2YJ3f+tI8bKileTRkmaV4IrGiJPqmpwI/YssmIK3xAtu4EYD3vrHGm54v7+Tk5CTMfmqCTZuyzn31Ibj/svmlA606FBFNygaSaJdaiESPA1CBH6HWlTL/etk8Z29feoFNEEUiM5Qrr4QLL8y/idXgbJY6oZYrA+nEZwEl2qUWpvUqVikYJLAv6l8VJlmbK2Xy69jXHb/oz78//omcQuUm9PKrh5r/NCkr0SjB5DOaZI3Dnj1Zl1o7MH/kHDBbj5th9yUz/G4izUFfcl3jnh/DaavqXE526oWfMKFJ2XFIsTMdhQT3cFTgR6S5UuZesry9edDQ1FTYcY1dLtuccrjiwDy35VbdVPI5GTdl7/1JMIcvrMCb2THAAvC0u19S1HbLqnWlTLcjUatichLO/XiNm94J9x5fY/X/ZkfwTm5TdzlSCXamI5XQHk+RHfw1wG5gbYHbLK/Gi6nsK2WKdvmfTnD+H87wpj2N52SbusuRaC1Sem77l9AeTyEF3szWAW8APgS8t4htlk1z6d+Gk+rZqXgba41TWClTtGUHUrXrLhPqoIJJqEiNXUJ7PEV18B8FtgAndbqDmW0GNgOsX7++oIeNQ37p3zv2N6IZ0AurH+26SxWn4SVUpMYuoT2eoS/ZZ2aXAPvcfWe3+7n7VnefdvfpyYSyitZrnm49XOOmVbPU36gX1orlrp25uAjf3JG7DKF0V4Vrq45LApf1K+KarOcBG81sD3AP8Foz+2wB2y2F5nLI5jVPIZtU/fHzemGtWKOD2r5jgjPPhPs2ZkfHfvPdCV8vtChVuLbquCTwXA4d0bj7jcCNAGb2GuB6d3/LsNsti+Y51N/O/LJrnmrp33Dye0afaJzv5p7P19i1mDvlgzL6oymaKU4Cz6XWwQ+jni3x+8ztNa69trbsmqcJpVBBNPeMDhyAnzDBR5hh7XGNg6IeUka/jFbMjEYCz2WhBd7dHwIeKnKbUWvswl0+C+c/OcOePTO8e0rFvQjNPaO8IwdFtXZWVV91o0lp6UAd/DByhWZyQoW9SJOT2YFhmzbBscfScqBYS2dV9QKXQJQQtRI3EEVMslaPViqMxZVXwhNPwIMPZh87nnUyt+oGSGL1Q0/5/2MzStDf4WiUeLJVHfxKVL1jHKO+LijempVW4fdThf9jLEq8h6QCvxIl/oVXQqoZff7/ob/B8SnxZKsimkEomimH1siixLvYy+T/H4plxq+E0Z86+EFot7icytzRq2uPRwlf/yrwg9ALrJx6ZfQxF/zWsZaksCSphK9/FfhBlDiLk5zWF2psnZm69jiV8PWvAt+PmDs8GVzrCzW2CEdduxREk6z9SGWSTtrrNSk7jsm1/GO0ruuXeJRsolUdfD+0m1wtvSKcojr8/HbUtZdDbHFeDyrw3eRfgCX4ZUpBekU4vQp+/nbz/u2+l9+OmohyKNnvSQW+m5K9W8uIDFrw87eh8/fy2ynhBF4llez3pALfTcnerWVMehX8dn837b5XsmIhOaEn4vtk7j72B52envaFhYWxP66ISCHm5rI9sdnZsb5Jm9lOd5/u9/7q4DspyTu0iARQkr17LZPsREsjRaSTkpwLSB18JyV5hxYR
6UQdfCudMVJE+lGCg55U4FspmhGRfpSgViiiaaVoRkT6UYJaoWWSIiIlMegySUU0TSXI00REBqEC31SCPE1EIhRxczh0Bm9mZwB/C7wIcGCru39s2O2OXQnyNBGJUMTnrCpikvUwcJ277zKzk4CdZrbD3R8tYNvjo/OCiMhKRNwcDh3RuPsz7r6r8fnzwG7g9GG3OzYR716JSAlEfFRroRm8mU0BZwPfaPO9zWa2YGYLi4uLRT7scJS9i8iwIm0UC1sHb2YnAp8DrnX3n7Z+3923AlshWyZZ1OMOLeLdKxEpiUhz+EIKvJkdS1bc73L3+4rY5tgoexeRYUXaKA4d0ZiZAXcCu939r4Yf0hhFulslIiUTaQ5fRAZ/HvBW4LVm9q3Gv9cXsN3RU/4uIgkbOqJx968BVsBYxi/S3SoRKanILhRU7SNZI92tEpGSiiwVqObZJCN7lxWRRESWClSzg4/sXVZEEhFZKlDNDj6yd1kRkVGoZgcf2busiCQkouXX1SrwET3xIpKoiCLgakU0kR5OLCIJiSgCrlaBj+iJF5FERXT6k2oV+IieeBGRUatOBq/8XUTGJZJ6U50CH9HEh4gkLpJ6U52IRvm7iIxLJPXG3Md/7Y3p6WlfWFgY++OKiJSZme109+l+71+diEZEpGLSL/CRTHaISMVEUHvSL/CRTHaISMVEUHvSn2SNZLJDRComgtqjSVYRkZLQJGteBBmYiEgoaRf4CDIwEamwwE1m2hl8BBmYiFRY4DPYpl3gdXIxEQkpcJOZdoEXEQkpcJOZZgavyVURiUXAelRIgTezi83sMTN73MxuKGKbQ9HkqojEImA9GjqiMbNjgDuAi4C9wCNmdr+7PzrstldMk6siEouA9aiIDv6VwOPu/iN3PwjcA1xawHZXrpl7TUx0v5+iHBEZtX7r0QgUUeBPB57K3d7b+NoyZrbZzBbMbGFxcbGAh+1gkKLduuukgi8iCRnbJKu7b3X3aXefnpycHN0DDZJ31WowO7u065T/WRV7ESlCwFpSxDLJp4EzcrfXNb4WxiB5V+sSpvzPth6gUK9nX6vVguxqiUhJBTzYqYgC/whwlpltICvsVwB/VMB2V2aYdaf5n219owh8RJqIlFTASdZCziZpZq8HPgocA3zK3T/U7f6lPJtkvoMHdfMiMnZBzibp7l92999w9xf3Ku4jM+qcKz8TrslZESmBdI5kHefBBN0mZ0EFX0SWC1QT0jkXzThzrm6Ts7A8r29O2CrOEamuQHN46RT4kCf1GWQ1johUT6CJ1jQu2RfzEsbWscU8VhGJWjUv2RfzycVaD1PWwVQiMiZpRDRlOrmYDqYSqZ5Ar+00CnyZrtykg6lEqkeTrBXUbXJW3bxIOgKlDOXP4FPKsXUwlUiaAp0yuPwdfKqxRq/4Rh2+iPRQ/gJfpgnWQehgKpG0BGjKyl/gyzTBOgwdTCVSbgFep+Ut8FWPKLqtxqn6cyMSowBpQ3knWWM+uGncuh1MBZqgFYlBgInW8nbwqWbvRdD6epH4KIMfQFWy95XQ+nqR+CiDl5HIF/y5OS23FAlBGXyflCmvnC5WIhKGMvg+KVNeuUHW1+u5FSm1chZ4TbAWp1fBV4QjUowAr6VyRjSBzutQCVpyKTIaAZZ2l7ODl/FRhCNSjADJQ7ku2ae4ILz87wD0+xAZo7Fess/M5szsB2b2HTP7vJmdMsz2etLRq+HplMYiKxPg9TFsRLMDuNHdD5vZXwI3Au8bflgdaHI1LopvRPpXtgOd3P2B3M2vA28ebjg96OjVuGgFjkj/AjSoRU6yXg38XadvmtlmYDPA+vXrC3xYiUZrwVdHL7IkQIPas8Cb2YPAaW2+dbO7f6Fxn5uBw8Bdnbbj7luBrZBNsq5otFIu6uhFgupZ4N39wm7fN7OrgEuACzzEkhyJV6+OXgVfqqRsZ5M0s4uBLcDvu/vPixmSJEuTslJlZZtkBT4O/BKww8wAvu7ufzb0qCRNmpSV
KivbJKu7/3pRA5EK0qSsVEmASdZynotG0tR6KmPQwVOSjgB/yyrwEo92J5HT0bKSCp1sTKSFJmYlFTrZmEgPrROxmpiVChnrycZExq7X+epFYlXCk42JhKWlllIWJVwHLxKWllpKWZRtHbxIdNq9iNTVSwy0Dl5kSP0stRQZt0DLe9XBS/qU00togaJDFXhJn3J6CS3Q1ehU4KV61NHLuAW6Gp0yeKkeraWXcVMGLxKIVt7IqCmDFwmk3e6zcnopkjJ4kYgop5ciKYMXiYhyeilKwFNcq4MX6YdyelmpgHGfCrxIP5TTy0oFyt9BBV5k5dTVSz8C5e+gDF5k5XTeG+kl8CUm1cGLFEmrbyQvcIynAi9SJJ33RvIC5u+gAi8yWsrpqyuC33MhGbyZXWdmbmb6axXJU05fXRH8nofu4M3sDOB1wJPDD0ekAtTVV0PgeAaK6eBvB7YAXsC2RNKnrj59kbxhD9XBm9mlwNPu/m0z63XfzcBmgPXr1w/zsCLp0eqbtEQyud6zwJvZg8Bpbb51M3ATWTzTk7tvBbYCTE9Pq9sXydPqm7REEM9AHwXe3S9s93Uz+21gA9Ds3tcBu8zsle7+bKGjFKka5fTlFvDo1bwVZ/Du/l13f6G7T7n7FLAXOEfFXaQAyunLKfCRq620Dl6kLNTVxy+yaK2wAt/o4kVkVHRGy7jV67B/P9xyS/DsvUknGxMps1oNZmeXr76JKCKolPl5uPVWOOGEaPamFNGIlJlW38QjkpUzeergRVLS2tGDuvpRaz6/cPTEeGAq8CIp6Xf1jYp+cSJe3aSIRiR17aIDRTnFiTCaaVKBF0ldu9U3OjXC8PLPWaRvkopoRKqoNcqJOGaIVgmeM3XwIqKDqAbRfF42bsxuRxjNNKnAi0h/B1Gp4GdKNH+hAi8i7bV29e0KW5WKfok69yYVeBFpr7Wrr+pqnGZh378/O1IVSvN/1SSriPSn3Rr7KpwqIf8m1noQWeTUwYvIyvU6VUJZI5z8uPN7LmX6P6AOXkSK1NrRty4ljLnDz48tP+52ey4loQ5eRIrTK7ePscNvl7FHfHTqINTBi8jotHa/vTp8WN5JF9Xxd9tmu4y9xF17njp4ERmfQVfmQPeOv9vt5rZqte7bLHHG3osKvIiE0895cvKft0Y83W7D0ufdthnJBbJHwdx97A86PT3tCwsLY39cESm5lXbwiXTmZrbT3af7vr8KvIhIOQxa4DXJKiKSKBV4EZFEqcCLiCRKBV5EJFEq8CIiiVKBFxFJlAq8iEiigqyDN7NF4ImxP3B7E0CEp7ZrS2MtXlnGCRrrKJRlnJCN9QR3n+z3B4IU+JiY2cIgBw6EpLEWryzjBI11FMoyTljZWBXRiIgkSgVeRCRRKvCwNfQABqCxFq8s4wSNdRTKMk5YwVgrn8GLiKRKHbyISKJU4EVEElXpAm9mF5vZY2b2uJndEHo8nZjZGWb2z2b2qJl938yuCT2mbszsGDP7ppl9KfRYujGzU8zsXjP7gZntNrNXhx5TJ2b2nsbv/ntmtt3Mjg89piYz+5SZ7TOz7+W+9itmtsPMftj4+Mshx9gYU7txzjV+/98xs8+b2Skhx9jUbqy5711nZm5mPa9iUtkCb2bHAHcAfwC8FLjSzF4adlQdHQauc/eXAq8C/jzisQJcA+wOPYg+fAz4irv/JvAKIh2zmZ0OvBuYdveXAccAV4Qd1TLbgItbvnYD8FV3Pwv4auN2aNs4epw7gJe5+8uBfwduHPegOtjG0WPFzM4AXgc82c9GKlvggVcCj7v7j9z9IHAPcGngMbXl7s+4+67G58+TFaLTw46qPTNbB7wB+GTosXRjZicD5wN3Arj7QXf/n7Cj6moVsNrMVgFrgP8IPJ4j3P1fgP9q+fKlwKcbn38auGysg2qj3Tjd/QF3P9y4+XVg3dgH1kaH5xTgdmAL0NfqmCoX+NOBp3K39xJp0cwzsyngbOAbYUfS0UfJ/gD/L/RAetgA
LALzjTjpk2Z2QuhBtePuTwMfIevangGec/cHwo6qpxe5+zONz58FXhRyMH26GviH0IPoxMwuBZ5292/3+zNVLvClY2YnAp8DrnX3n4YeTyszuwTY5+47Q4+lD6uAc4BPuPvZwH7iiBGO0sivLyV7U/o14AQze0vYUfXPs7XYUa/HNrObyaLQu0KPpR0zWwPcBLx/kJ+rcoF/Gjgjd3td42tRMrNjyYr7Xe5+X+jxdHAesNHM9pBFXq81s8+GHVJHe4G97t7cE7qXrODH6ELgx+6+6O6HgPuA3ws8pl7+08x+FaDxcV/g8XRkZlcBlwB/7PEeGPRisjf4bzdeX+uAXWZ2WrcfqnKBfwQ4y8w2mNlxZJNW9wceU1tmZmRZ8W53/6vQ4+nE3W9093XuPkX2fP6Tu0fZabr7s8BTZvaSxpcuAB4NOKRungReZWZrGn8LFxDphHDO/cDbGp+/DfhCwLF0ZGYXk0WKG93956HH04m7f9fdX+juU43X117gnMbfcUeVLfCNiZV3Av9I9mL5e3f/fthRdXQe8FayjvhbjX+vDz2oBLwLuMvMvgP8DnBb4PG01djLuBfYBXyX7HUbzSH2ZrYd+DfgJWa218w2AR8GLjKzH5LtgXw45Bih4zg/DpwE7Gi8rv4m6CAbOox18O3Eu0ciIiLDqGwHLyKSOhV4EZFEqcCLiCRKBV5EJFEq8CIiiVKBFxFJlAq8iEii/h+J8nD5g35sfAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Receding-horizon shift: drop the first (step + 1) applied controls and zero-pad the tail.\n",
    "step = 0\n",
    "s2_ = s_[step:]\n",
    "x_, y_ = s2_[:, 0].data.numpy(), s2_[:, 1].data.numpy()\n",
    "\n",
    "controls2_ = _controls.data.clone()\n",
    "controls2_[:-(step + 1)] = _controls[step + 1:]\n",
    "controls2_[-(step + 1):] = 0\n",
    "# NOTE(review): the factor 3 (waypoint advance per step) is hard-coded - confirm.\n",
    "start2_ = start + 3 * (step + 1)\n",
    "\n",
    "plt.scatter(x_, y_, c='b', s=20)\n",
    "plt.scatter(cx, cy, c='r', s=1)\n",
    "(controls2_, _controls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([1.8064, 2.2475, 2.6507, 2.8838, 3.0197, 3.0864, 3.1329, 3.1753, 3.1753,\n",
       "         3.1753], requires_grad=True), Parameter containing:\n",
       " tensor([[-0.0636,  0.5025],\n",
       "         [-0.0854,  0.4851],\n",
       "         [-0.0546,  0.3112],\n",
       "         [-0.0061,  0.1693],\n",
       "         [ 0.0114,  0.0673],\n",
       "         [-0.0148, -0.0405],\n",
       "         [ 0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000],\n",
       "         [ 0.0000,  0.0000]], requires_grad=True))"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Warm start for the next horizon: shift the previous controls/speeds forward and pad the tail.\n",
    "# NOTE(review): start_ advances by vs[:3] while controls/speeds shift by 2 - confirm the offsets agree.\n",
    "state0_ = s2_[0]\n",
    "start_ = start + vs[:3].sum().round().int()\n",
    "\n",
    "controls_ = torch.cat([_controls[2:], torch.zeros_like(_controls[:2])], dim=0)\n",
    "controls_ = torch.nn.Parameter(controls_.data.clone())\n",
    "\n",
    "vs_ = torch.zeros_like(vs)\n",
    "vs_[:-2] = vs[2:]   # shift speeds forward by two steps\n",
    "vs_[-2:] = vs[-1]   # repeat the last speed to fill the tail\n",
    "vs_ = torch.nn.Parameter(vs_.data.clone())\n",
    "(vs_, controls_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 6.3231,  2.9579,  2.6732,  0.8252, -0.5831,  0.1357],\n",
       "       grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 94,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Start state and starting waypoint index for the next horizon.\n",
    "state0_ = s2_[0]\n",
    "start_ = start + vs[:3].sum().round().int()\n",
    "state0_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 6.3231,  2.9579,  2.6732,  0.8252, -0.5831,  0.1357],\n",
      "       grad_fn=<SelectBackward>)\n",
      "Parameter containing:\n",
      "tensor([[-0.0636,  0.5025],\n",
      "        [-0.0854,  0.4851],\n",
      "        [-0.0546,  0.3112],\n",
      "        [-0.0061,  0.1693],\n",
      "        [ 0.0114,  0.0673],\n",
      "        [-0.0148, -0.0405],\n",
      "        [ 0.0000,  0.0000],\n",
      "        [ 0.0000,  0.0000],\n",
      "        [ 0.0000,  0.0000],\n",
      "        [ 0.0000,  0.0000]], requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([0.8000, 0.8000, 0.8000, 0.8000, 0.8000, 0.8000, 0.8000, 0.8000, 0.8000,\n",
      "        0.8000], requires_grad=True)\n",
      "6.667714 -0.80000013\n",
      "1.916621 -0.87454194\n",
      "1.2979115 -0.91283685\n",
      "1.1404617 -0.9428396\n",
      "0.96841985 -0.9706794\n",
      "0.7207286 -0.99533665\n",
      "0.6464384 -1.0122913\n",
      "0.5910181 -1.0235221\n",
      "0.56214887 -1.0318235\n",
      "0.47636652 -1.0420655\n",
      "0.4047236 -1.0545623\n",
      "0.36866122 -1.0686553\n",
      "0.32675186 -1.0746917\n",
      "0.31590736 -1.0798696\n",
      "0.28388274 -1.0842228\n",
      "0.24586266 -1.0910887\n",
      "0.22480865 -1.0989162\n",
      "0.20326872 -1.1125824\n",
      "0.16586623 -1.1272522\n",
      "0.15343851 -1.134539\n",
      "0.15883766 -1.137222\n",
      "0.13490349 -1.1414528\n",
      "0.13315864 -1.1488398\n",
      "0.13239609 -1.1554948\n",
      "0.1340699 -1.1640435\n",
      "0.109812476 -1.168508\n",
      "0.08777419 -1.1718596\n",
      "0.071525484 -1.1757532\n",
      "0.070363745 -1.1819375\n",
      "0.069355465 -1.1899847\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f59a0627710>"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXrElEQVR4nO3dfYxldX3H8c+Xh9Vll0rLjNqyuxliqa2x6uJotBhLBA1VMmjjH2A1ctlkbdIqtrhbHqSEKo2d9QmVWDfiqJFAG9S6MdaKtqQxqYZZwCdWKtFVlkKZsa2yK7IQv/3j3MucuXuf77n393Der2Qz92nO/e2de77ne76/h2PuLgBAfo4L3QAAwGQQ4AEgUwR4AMgUAR4AMkWAB4BMnRDiTWdmZnxubi7EWwNAsvbv37/q7rODvj5IgJ+bm9Py8nKItwaAZJnZj4d5PSUaAMgUAR4AMkWAB4BMEeABIFMEeADIFAEeADJFgAeATBHgAUzP6qq0Z0/xs3X73nvXHkOlgkx0AlAzq6vS0pJ05Ih07bVrj+/eLd1+u/SlLxXPbdokNRrSzEywpuaEAA9gctoD+zXXSIuLRRBvWViQzj67eM3u3QT6ChHgAVSvV2AvB+1du9Z+rq4Wgb0V6KXi9UtLBPsREeABVG9pqQjS3QJ7JzMz6wN9K7iT1Y+MAA+gGq2svdFYK8GMEpBbgb71+9L6rL71HPoiwAMYT6cO1F27qgnE7Vn9wkIx4oZMfiAMkwQwnlYZRTq2A7UqrUC/b1/xXktL1b9HhsjgAYxnnHLMqO9FJj8QMngAo2lNVJKK7HoagZZMfiiVBHgzO8XMbjWz75vZATN7aRXbBRCxVmkmRJBtNNbKQeXZsVinqhLN9ZK+7O6vN7MNkk6qaLsAYtPqVF1YKO5PoubeT3mkzZ49jLDpYuwAb2ZPk/RySRdLkrsflXR03O0CiFS5UzWGgEpdvqsqMvjTJa1IWjKz50vaL+lSdz9SfpGZ7ZS0U5K2bdtWwdsCCKLcqRqDVjZPJn+MKmrwJ0g6U9JH3X27pCOSLm9/kbvvdfd5d5+fnZ2t4G0BTFWITtVhlOvykFRNgD8k6ZC7f7N5/1YVAR9ATkJ2qg6ivS5Pp+v4JRp3f8jM7jezZ7v7vZLOkXTP+E0DEJXYSjPdxNZHEFBVo2jeKumm5giaH0qK/BsAYGDlNWZSCJipHIimoJJx8O5+d7O+/jx3f627/28V2wUQgdhLM+0o1TyJpQoA9JZqRkyphgAPoINyWaacEack1QNThViLBsCxUivLdEKphgweQAc5Zb81LtUQ4AGsSW3EzCByOlgNiRINgDU5lGbatUo1MzO1W3mSDB7Amtyz3ZqVawjwAPIszXSS+wGsDSWaFLWfZpbvD3IbaJdjaaaTmo2sIYNPSber15dPO6X+t1tXqW9lbNL6Mc+on5pltnUp1RDgY1cOxK0v5TXXrF8WtdPO2e92t4NCxl929JDqZKZR1eSAZu4+9Tedn5/35eXlqb9vMjoF9VZAryrT7pXBt89iRL74WyfFzPa7+/ygryeDj1E5uy5nGlVmWe3bKt9uP30lCOSrJqWKrjL/bhPgY1H+ok0qqA+q/fS1/YCT8Q5ROzUpVXSV+QGOAB9at47TkF+29oNKOQhkvkPUQg4LiVUl8wMcwyRDKwfMWK8nWZ4J2H7dy7bhlysr0h13SAcOFD9XVgK2G53VZUjkIDIfNkkGH0Kvckzs2jO+0gHq5i27tGNH8fCjj0obNxa3r7xSestbJK61HonMs9aR5Hpm6u5T//fCF77Qa21x0V0qfqZuZcV9cdFXDqz4xo3up2rF36FFP1UrLvmT/zZscH/Xu9wffjh0g2us+bfylZXQLYlPIp+NpGUfItZSopmmVjljYSHecsywmhn9jx6Z0YYNUkNL2qPdamj96f/Ro9LVV0tbt0rvfjclnCAozXRXLkNmhBLN
NOV6Gihpbq4I4kvN6623fp6qVTW0pCU19FPN6LHHikB/9dVrJZwbb5QuuihQw+uE0kx/mQ2bJIOftHInZHsHZUZmZ4tA/YuNM7ph4y79VMXO0S2jl4o6/aOPSjt2kMlPRaZZaqUyO8shg5+09qw9s8y97KKLpHPPlQ4eLLL5s8+Wlp7ondFL0oknFr9DJ+yEZJaVTlRmZzkE+Elp7VQLC8X9TL4w/czOrgXqT39auuSSGb33l2sHtVZGL0nvVfH4448XJR5MSMalwcplNi+gsgBvZsdLWpb0gLufX9V2k8VO9WRG/7GPSdddV4ynWXqsoeOOk5Z+1dDGjdKpvqp9r13SrDUkkV1ORGZZ6cRldMZTZQZ/qaQDkn6twm2mi51KUpHNv/OdxTj4gwelzZtndPjwLl28WTp8WPq9Ly5p89/slrbrmAPhykrxO3NzlG+GxmzV0WWUnFUS4M1si6TXSLpO0l9Wsc1k1eXKOEMql27WOb0hbdb6A+Hqqu5625IWPt/QI0+Z0dGjjLQZWkZBauoySs6qyuA/KGm3pJO7vcDMdkraKUnbtm2r6G0jxI41nA7Z5eEPL2n7zbt1ofRk/X7HjqLcQyY/oIyC1NRldMYz9jBJMztf0sPuvr/X69x9r7vPu/v8bM57acZDIaflBy9r6OqnLq4beXPZr/bo0N15rRMyEa1huRJDIseVwWUuqxgHf5akBTM7KOkWSa8ws89UsN10lL8IjDUe25YXzOh9tn4s/bse260zvp7H2OSJymwcd1AZfJZjl2jc/QpJV0iSmZ0t6R3u/sZxt5sUyjKVak2a2rGjGCN/y9GG3vA6aftbSytYZjLKoXKUZqqTwWfJOPhx1HSs+zSUJ03Nzc1odrbHFafqjhEzk5HBZ1lpgHf32yXdXuU2o0agmaiuI2/aM6u6Z/R8D9EFGfw4MjiFS1KPNelrGeD4Hk5WwgkEi42NgpEKcelzlaks0bE/PQl3tpLBj6LuGWNsumT0hw9LB87fledMWL6D05PwGRIBfhQJ/8FrodHQXXdJC4sNPXK9dPJjq9r3uiVt/1B6p9jrdLvUIyYr4c5WSjTDoDSThBWf0Vn/tEuHfjmjn/1MuvCXxczYwx9O7xR7nXKpgLLM9CVY+iODHwanxUk4eFDasKG4mIhUrEX/1KdIf/yyhrZLaXWakbXHI8H9nwA/DHawJLQuH9jyU83ofcft0tte0HygfUeNOeDX6IIx0Utw/yfADyPhWlydtM+Effzx4v6THa3tO2psmRlZe5wS3P8J8IOIOcNDR+tnwraNomnfUWObOEXWjorQyTqIhMfB1tnsrPSiFw0wRLK9w7L97z2NzrWaXJw9eYl1tJLBD4LT5HrpV8KpKsMvb4esPQ2xlfP6IMD3wtWZ6qlfCadfwC/fb72+03Pl7ZBEpCGxvxMBvpfEjtaYkGEDfvm+1P258nYS7MCrpcT+TgT4XhI7WmNK+gX8Tt+bTs8lFixQErojfkDm7lN/0/n5eV9eXp76+wJAJfbsKc7EFhenepA2s/3uPj/o68ngu0nkCA0ggETO7hkm2Q1DIwF0k8haQGTw3SRyhAaAbsjg27FiJIBBRT7xiQDfjtIMgEFFHi8o0bSjNANgUJHHC4ZJAkAihh0mSYmmJfJaGgAMiwDfEnktDUCkIk4Ox67Bm9lWSZ+W9AxJLmmvu18/7nanLvJaGoBIRbxmVRWdrE9Iuszd7zSzkyXtN7Pb3P2eCrY9PawLAmAUESeHY5do3P1Bd7+zefsRSQcknTbudqcm4tMrAAmIeFZrpTV4M5uTtF3SNzs8t9PMls1seWVlpcq3HQ+1dwDjijRRrGwcvJltlvRZSW9395+3P+/ueyXtlYphklW979giPr0CkIhI6/CVBHgzO1FFcL/J3T9XxTanhto7gHFFmiiOXaIxM5N0o6QD7v7+8Zs0RZGeVgFITKR1+Cpq8GdJepOkV5jZ3c1/r65gu5NH/R1AxsYu0bj71yVZBW2Z
vkhPqwAkKrILBdV7Jmukp1UAEhVZVaCeq0lGdpQFkInIqgL1zOAjO8oCyERkVYF6ZvCRHWUBYBLqmcFHdpQFkJGIhl/XK8BH9MEDyFREJeB6lWginU4MICMRlYDrFeAj+uABZCqi5U/qFeAj+uABYNLqU4On/g5gWiKJN/UJ8BF1fADIXCTxpj4lGurvAKYlknhj7tO/9sb8/LwvLy9P/X0BIGVmtt/d5wd9fX1KNABQM/kH+Eg6OwDUTASxJ/8AH0lnB4CaiSD25N/JGklnB4CaiSD20MkKAImgk7UsghoYAISSd4CPoAYGoMYCJ5l51+AjqIEBqLHAK9jmHeBZXAxASIGTzLwDPACEFDjJzLMGT+cqgFgEjEeVBHgzO8/M7jWz+8zs8iq2ORY6VwHEImA8GrtEY2bHS7pB0islHZJ0h5ntc/d7xt32yOhcBRCLgPGoigz+xZLuc/cfuvtRSbdIuqCC7Y6uVfeamen9Oko5ACZt0Hg0AVUE+NMk3V+6f6j52DpmttPMls1seWVlpYK37WKYoN1+6kTAB5CRqXWyuvted5939/nZ2dnJvdEw9a5GQ1pcXDt1Kv8uwR5AFQLGkiqGST4gaWvp/pbmY2EMU+9qH8JU/t32CQqrq8VjjUaQUy0AiQo42amKAH+HpDPM7HQVgf1CSW+oYLujGWfcafl32w8UgWekAUhUwE7WSlaTNLNXS/qgpOMlfcLdr+v1+iRXkyxn8BLZPICpC7KapLt/yd1/x92f1S+4T8yk61zlnnA6ZwEkIJ+ZrNOcTEDnLIBhBIoL+axFM806F52zAIYRqA8vnwAfclEfOmcB9BKoozWPS/bFnCW3ty3mtgKIWj0v2Rfz4mLt05Sp1wOYkjxKNCktLka9HqinAPt3HgE+pSs3Ua8H6inA/p1HgE9Vr9E4ZPNAXgJUGtKuwedWw2YyFZCvAMsGp53B51zS6Fe+IcMH0EfaAT6lztVh9SrfSOsDfqvDlmAPxItO1iGl1Lk6rmFmzwKID52sQ6h7iaLXaJy6fzZAjOhkHULMk5umrddkKokOWiAGdLIOIef6+7gYXw/Ehxr8EOpUfx8W4+uB+FCDx0SUA/6ePQy3BEKgBj8gasqj63WxEonPFpgUavADoqY8umHG1/PZAklLM8DTwVqdfgGfEg5QjQD7UpolmgCnOrXB+vXAZAQY2p1mBo/pYcYsUI0AlYe0LtlHuSAsLj8IBDXVS/aZ2R4z+76ZfdvMPm9mp4yzvb6YvRoWM2aB0QXYP8Yt0dwm6Qp3f8LM/k7SFZL+avxmdUHnalwYgQMMLrWJTu7+ldLdb0h6/XjN6YPZq3FhBA4wuAAJapWdrJdI+oduT5rZTkk7JWnbtm0Vvi2i0R7wyeiBNQES1L4B3sy+KumZHZ66yt2/0HzNVZKekHRTt+24+15Je6Wik3Wk1iItZPRAUH0DvLuf2+t5M7tY0vmSzvEQQ3IQr34ZPQEfdZLaapJmdp6k3ZL+0N1/UU2TkC06ZVFnqXWySvqIpKdIus3MJOkb7v6nY7cKeaJTFnWWWieru/92VQ1BDdEpizoJ0Mma5lo0yFP7UsYSk6eQjwDfZQI84tFpETlmyyIXLDYGtKFjFrlgsTGgDxY8Q41NdbExYOr6LXgGxCrBxcaAsBhqiVQkOA4eCIuhlkhFauPggeh02onI6hEDxsEDYxpkqCUwbYGG95LBI3/U6RFaoNIhAR75o06P0AJdjY4Aj/oho8e0BboaHTV41A9j6TFt1OCBQBh5g0mjBg8E0un0mTo9qkQNHogIWT2qRA0eiAjj6VGVgEtck8EDg2L0DUYRsNxHgAcGxXh6jCJQ/V0iwAOjo06PQQSqv0vU4IHRUadHP4EvMUkGD1SJOj3KApfxCPBAlajToyxg/V0iwAOTRZ2+viL4O1dSgzezy8zMzYxvK1BGnb6+Ivg7j53Bm9lWSa+S9JPxmwPUAFl9PQQuz0jVZPAfkLRbklewLSB/ZPX5i+SA
PVYGb2YXSHrA3b9lZv1eu1PSTknatm3bOG8L5IfRN3mJpHO9b4A3s69KemaHp66SdKWK8kxf7r5X0l5Jmp+fJ9sHyhh9k5cIyjPSAAHe3c/t9LiZ/b6k0yW1svctku40sxe7+0OVthKoG+r0aQs4e7Vs5Bq8u3/H3Z/u7nPuPifpkKQzCe5ABajTpynwzNV2jIMHUkFWH7/ISmuVBfhmFg9gUrjyVNxWV6UjR6Rrrglee29hsTEgZY2GtLi4fvRNRCWCWllakq69Vtq0KZqzKUo0QMoYfROPSEbOlJHBAzlpz+glsvpJa32+0rEd44ER4IGcDDr6hqBfnYhHN1GiAXLXqXRAKac6EZZmWgjwQO46jb5haYTxlT+zSA+SlGiAOmov5URcZohWAp8ZGTwAJlENo/W5LCwU9yMszbQQ4AEMNomKgF9IqP+CAA+gs/asvlNgq1PQTyhzbyHAA+isPauv62icVmA/cqSYqSol83+lkxXAYDqNsa/DUgnlg1j7JLLIkcEDGF2/pRJSLeGU210+c0np/yAyeABVas/o24cSxpzhl9tWbnenM5dEkMEDqE6/un2MHbWdauwRz04dBhk8gMlpz347LYbWK8uvKuNv306nbF1aa1vCWXsZGTyA6Rlk2YRywJV61/R73W9tq9E49syhfD/hGns/BHgAYQ0yHLNbiafXfalzEG//GckFsifB3H3qbzo/P+/Ly8tTf18AiRs1g88kMzez/e4+P/DrCfAAkIZhAzydrACQKQI8AGSKAA8AmSLAA0CmCPAAkCkCPABkigAPAJkKMg7ezFYk/Xjqb9zZjKQIl7briLZWL5V2SrR1ElJpp1S0dZO7zw76C0ECfEzMbHmYiQMh0dbqpdJOibZOQirtlEZrKyUaAMgUAR4AMkWAl/aGbsAQaGv1UmmnRFsnIZV2SiO0tfY1eADIFRk8AGSKAA8Amap1gDez88zsXjO7z8wuD92ebsxsq5n9m5ndY2bfM7NLQ7epFzM73szuMrMvhm5LL2Z2ipndambfN7MDZvbS0G3qxsz+ovm3/66Z3WxmTw3dphYz+4SZPWxm3y099htmdpuZ/aD589dDtrHZpk7t3NP8+3/bzD5vZqeEbGNLp7aWnrvMzNzM+l7FpLYB3syOl3SDpD+S9BxJF5nZc8K2qqsnJF3m7s+R9BJJfxZxWyXpUkkHQjdiANdL+rK7/66k5yvSNpvZaZLeJmne3Z8r6XhJF4Zt1TqflHRe22OXS/qau58h6WvN+6F9Use28zZJz3X350n6T0lXTLtRXXxSx7ZVZrZV0qsk/WSQjdQ2wEt6saT73P2H7n5U0i2SLgjcpo7c/UF3v7N5+xEVgei0sK3qzMy2SHqNpI+HbksvZvY0SS+XdKMkuftRd/+/sK3q6QRJG83sBEknSfqvwO15krv/u6T/aXv4Akmfat7+lKTXTrVRHXRqp7t/xd2faN79hqQtU29YB10+U0n6gKTdkgYaHVPnAH+apPtL9w8p0qBZZmZzkrZL+mbYlnT1QRVfwF+Fbkgfp0takbTULCd93Mw2hW5UJ+7+gKT3qsjaHpT0M3f/SthW9fUMd3+wefshSc8I2ZgBXSLpn0M3ohszu0DSA+7+rUF/p84BPjlmtlnSZyW93d1/Hro97czsfEkPu/v+0G0ZwAmSzpT0UXffLumI4igjHKNZv75AxUHptyRtMrM3hm3V4LwYix31eGwzu0pFKfSm0G3pxMxOknSlpL8e5vfqHOAfkLS1dH9L87EomdmJKoL7Te7+udDt6eIsSQtmdlBFyesVZvaZsE3q6pCkQ+7eOhO6VUXAj9G5kn7k7ivu/rikz0n6g8Bt6ue/zew3Jan58+HA7enKzC6WdL6kP/F4JwY9S8UB/lvN/WuLpDvN7Jm9fqnOAf4OSWeY2elmtkFFp9W+wG3qyMxMRa34gLu/P3R7unH3K9x9i7vPqfg8/9Xdo8w03f0hSfeb2bObD50j6Z6ATerlJ5Je
YmYnNb8L5yjSDuGSfZLe3Lz9ZklfCNiWrszsPBUlxQV3/0Xo9nTj7t9x96e7+1xz/zok6czm97ir2gb4ZsfKn0v6FxU7yz+6+/fCtqqrsyS9SUVGfHfz36tDNyoDb5V0k5l9W9ILJP1t4PZ01DzLuFXSnZK+o2K/jWaKvZndLOk/JD3bzA6Z2Q5J75H0SjP7gYozkPeEbKPUtZ0fkXSypNua+9XfB21kU5e2Dr+deM9IAADjqG0GDwC5I8ADQKYI8ACQKQI8AGSKAA8AmSLAA0CmCPAAkKn/B97UVv3RfknRAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Receding-horizon optimisation: roll the model forward from state0_ and fit\n",
    "# controls_ (the first 2 control columns) and vs_ to the reference path with Adam.\n",
    "len_horizon = 10\n",
    "EMAX = 3000     # optimisation epochs per horizon\n",
    "kappa = 0.1     # weight of the (currently disabled) speed reward v_loss\n",
    "TMAX = 1\n",
    "print(state0_)\n",
    "dt = torch.ones(len_horizon, 1)   # appended as the last control column\n",
    "\n",
    "vs_ = torch.nn.Parameter(torch.ones(len_horizon) * 0.8)\n",
    "print(controls_)\n",
    "print(vs_)\n",
    "opt = torch.optim.Adam([vs_] + [controls_], lr=0.001)\n",
    "dev = deviation_error_.apply\n",
    "for T in range(TMAX):\n",
    "    for epoch in range(EMAX):\n",
    "        controls = torch.cat([controls_, dt], dim=1)\n",
    "        # Detach the start state: it still carries the previous rollout's graph\n",
    "        # (grad_fn=SelectBackward), which forced retain_graph=True and pushed\n",
    "        # gradients into whatever produced s_ on every backward().\n",
    "        s = state0_.detach()\n",
    "        s__ = []\n",
    "        for t in range(len_horizon):\n",
    "            s = model(s, controls[t])\n",
    "            s__.append(s.view(1, -1))\n",
    "        s__ = torch.cat(s__, dim=0)\n",
    "        if epoch == 0:\n",
    "            s0_ = s__.data.clone()\n",
    "        dev_loss = dev(s__, vs_, waypoints, start_)\n",
    "        v_loss = -kappa * vs_.sum()\n",
    "        loss = dev_loss   # use dev_loss + v_loss to enable the speed reward\n",
    "        if epoch % 100 == 0:\n",
    "            print(dev_loss.data.numpy(), v_loss.data.numpy())\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "x_ = s__[:, 0].data.numpy()\n",
    "y_ = s__[:, 1].data.numpy()\n",
    "\n",
    "plt.scatter(x_, y_, c='b', s=20)\n",
    "plt.scatter(cx, cy, c='r', s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f59a0490f98>"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD4CAYAAADmWv3KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAYqElEQVR4nO3df4xlZX3H8ffXXZBhBdTOgMoyDrZoQ2hl9WJVbEsFDdbtrjbGgNXguMlEUwEt7hYkdUOaNnTXn1Vjs3EZSMRFg1gnDVUQpcakArOI8mOlUl1gV3DvaEVZkR/x2z/Ovc7dy71zf5zn3uc553xeyWbu3B/nPjsz93Oe8z3PeR5zd0REpHyeEbsBIiIyGgp4EZGSUsCLiJSUAl5EpKQU8CIiJbU6xptOTk76zMxMjLcWESms3bt3L7n7VL/PjxLwMzMzLC4uxnhrEZHCMrP7B3m+SjQiIiWlgBcRKSkFvIhISSngRURKSgEvIlJSCniRwOp1uO227KtITAp4kYB27YIXvhBe97rs665dsVskVaaAFwmkXodNm+Cxx+CRR7Kvmzal2ZPXUUY1KOBFAtm7Fw4//ND7Djssuz8lOsqoDgW8SCAzM/DEE4fe9+ST2f2pKNJRhuSngBcJZGoKdu6EiQk4+ujs686d2f2pKMpRhoQRZS4akbI691w466wsMGdm0gp3KMZRhoSjHrxIYFNTcNpp6YU7FOMoQ8LpuwdvZlcA64ED7n5K477nAl8AZoC9wFvd/f/CN1NEQkn9KEPCGaQHfyVwdtt9FwM3uftJwE2N70XGTsP+BpPyUYaE03fAu/u3gJ+33b0RuKpx+yrgTYHaJdI3DfsT6SxvDf44d3+ocfth4Lic2xMZiIb9iXQX7CSruzvg3R43szkzWzSzxbo+fRKIhv2JdJc34H9qZs8HaHw90O2J7r7D3WvuXptS4U8C0bA/ke7yBvwCcF7j9nnAV3JuT2QgGvYn0t0gwyR3AWcAk2a2D9gKXA580cw2AfcDbx1FI0VWomF/Ip31HfDufm6Xh84M1BaRoU1NKdiLrF7XDnoUdCWriESlYa6jo4AXkWg0zHW0FPAiEo2GuY6WAl5EotEw19FSwItINBrmOlqaD15EotIw19FRwItIdBrmOhoq0cjYaEpfkfFSwMtYaKyzyPgp4GXkNNZZJA4FvIycxjqLxKGAl5HTWGeROBTwMnIa6ywSh4ZJylhorLPI+CngZWw01llkvFSiEREpqSABb2bvN7O7zewuM9tlZkeE2K6IiAwvd8Cb2fHABUDN3U8BVgHn5N2uiIjkE6pEsxqYMLPVwJHATwJtV0REhpQ74N19P/Bh4AHgIeARd7+h/XlmNmdmi2a2WNcljCIiIxeiRPMcYCNwIvACYI2Zvb39ee6+w91r7l6b0lAKEZGRC1GiOQv4sbvX3f1J4Drg1QG2KyIiOYQI+AeAV5rZkWZmwJnAngDbFRGRHELU4G8BrgVuB+5sbHNH3u2KiEg+Qa5kdfetwNYQ25Jiqdc1/YBIqnQlqwxNi3iIpE0BL0PRIh4i6VPAy1C0iIdI+hTwMhQt4iGSPgW8DEWLeIikT/PBy9C0iIdI2hTwkosW8RBJl0o0IiIlpYAXESkpBbyIjM/SEmzfDvfem31dWlq+b2kpdutKRzV4ERmdpSWYn4cNG2BhAQ4ehMsug5tvhuuvX37eli3ZY2vWLD93dhYmJ6M1vQwU8CISVjPUZ2ezr1u2LAf61q2wbVsW4meckT2n6eDBQ58Ly9tQ2A9FAS8iYTSDvdlLh+UAbw30ZlBv3rz82s2bs9c3e/DN5zZ3EM3evYJ+IObuY3/TWq3mi4uLY39fERmh7duzMN66NVwYt+80Qm67gMxst7vX+n2+evAikk9rnR3Chu/k5KG9+2YZBw49ApCOFPAiMpxOJZlRhW570G/YkB0xVLQn368gwyTN7Nlmdq2Z/cDM
9pjZq0JsV9JXr8Ntt2ma4Epq1schO3HaesJ0VJpBv7CQvff8vIZZriBUD/4TwFfd/S1mdjhwZKDtSsJ27crmgD/88GxmyZ07s/lppORGWZLpV3Nn0noiFlS2aZP7JKuZHQPcAbzI+9yYTrIWX72ereL02GPL901MwP33a26a0mueTN22LY1AbR9rX+KyTYyTrCcCdWDezF4K7AYudPeDbQ2bA+YApqenA7ytxNRc8KM14JsLfijgS6pTzz0FzbJNc8cDaex4EhCiBr8aeBnwGXdfBxwELm5/krvvcPeau9emlACFpwU/KqhZCllYyAI0tV7y7OzyRVSqyQNhAn4fsM/db2l8fy1Z4EuJacGPCmoGaCo993adTsBWXO4Sjbs/bGYPmtlL3P1e4EzgnvxNk9RpwY8KaJ12oBmgqWu9erbiQylDjaI5H7i6MYLmR0Ciu3gJTQt+lFwRR6ioJv87QQLe3e8A+j6zKyKJS/WE6iDUk9eVrCLSQRF77u3Uk1fAi0iLMvTc27VeFFUxWtFJRJalPhRyGK0nhys2fFI9eBFZVubebhnKTgNSD74MWidb6ue2SLvm3weUp+feroIXQqkHXyStY5Lh6cuiNfW63Zx2tdO2RvDBrtc1Vj55VejdVvCkqwI+dZ3Wt2xq3u50WN3rdrdttYd/zsDXjJOJK+NJ1V7KXIZq5+5j//fyl7/cZQX1uvu2bctf4dDv6/VDb+d9j/Zttb5n+3MHcOCA+8REtqnmv4mJ7H5JRPvvukryfoYiABZ9gKxVwKein1CP0Rb3oQP/1lvdjznm0IA/+ujsfklEAUMumALu3AYNeJVoYltpJfpY83+0v2f7IW2f9VrNOJmoIs4vMwoVKNVoFE0MrSNaOi171vzQpTKSob09rbMKrjA6RzNOJqr5N1f12RYrMD5ePfgYWkO9vbdeBO0fjBVO0GrGyYRU8YRqP0o8gkgBP07d1rIs8h9Vr/LN0hJTV84zVaQdWFmVOMhyKXGpRgE/at2GOZblAxaoXi9jUOIgy6X5N9wsN5aoM6KAH7Vu5Ziy6hX4AcfYS59af+bayXZXws5IsIA3s1XAIrDf3deH2m5hlbEcM4z2/3OnD5FCf7RKGFwjUcIOWMge/IXAHuDogNssLn2oOuv0Ierys9IUBzm07jRLGFwjUcJSTZCAN7O1wBuBfwL+LsQ2C0sjFVbW6SimQxnnuxfMs+HLs/zqmZOa4mAY7TtNdTL6V6LOWage/MeBLcBR3Z5gZnPAHMD09HSgt01Qif44xqYt9B/95Dzrdm3hHODDv8nu37QpG26pnnyf1GsfXol+drkvdDKz9cABd9+90vPcfYe719y9NlXGT2nzsG7DhuWLgGQoP3zNLP9wxDbmG2u3/x5LXPTb7ey7o3wXogTVetFZahfLFUmJLoAKcSXr6cAGM9sLXAO81sw+F2C7xVLGlXAiWXvqJB+xzfyM7Gc4yzz/+PgWTvp248pLzW3fma5QDasEP8/cJRp3vwS4BMDMzgA+4O5vz7vdQtCJrJFoTnGwaRMcdhhc88Qsb3szrDtfY+s70nmf0SjBZ1rj4PPQiayROXSKg0mmpjS2vivt8EajBKNqgga8u98M3Bxym0lSj2kspqa6nFTtNba+CoGvo8fxKfAOVD34YRT4F15KXaZHePRR2LN+cznH0evocXwKvAO1bA758arVar64uDj29w2mCj3EImsbR3/U40ssvHmedf9a8N/XGNfRlQ4S+Nyb2W53r/X7fM0HP4gqrDxfAnWf5PR/38y+30zyyCNwzm+ycfWPfrKAo3A6rR0wP69hkDEUcFSNSjSDUGmmEPbuzRb5fuyx7Pt5ZjnimfDXr5llHRSrZl+1yepSVsCfvwK+HzqpWijtSwX+jEk+8ozNXHBq444+5rCPGvjdTqBWcbK6lBTx5z/IAq6h/hVu0e0CLs5bdZ//vPvERLbI98RE9n1XgRYZz6XbouuSlsiLlDPgotsK+H5UeeX5Ajtw
wP3WW7OvAxk08Ff6vt/HWt9Df2/pirzzHTTgVaLpJfbhugyt6zj6XgZdpWql76G/x1SKKYaC1eEV8L3oxKr0CvxeX/t5TKFeDAW7ulXj4LtpPbG6sJD8L1JExmj79qzjt23bWHfMg46DVw++G/XcRaSbgpRqFPDdFOQXKCIRFKSkpitZ2+lqVRHpRwGuiFbAtyvg5cgiEkEBskIlmnYqzYhIPwqQFerBN6k0IyKDKMDarSEW3T7BzL5pZveY2d1mdmGIho1dAQ63RCRBCWdHiBLNU8BF7n67mR0F7DazG939ngDbHp8CHG6JSIISzo7cPXh3f8jdb2/c/hWwBzg+73bHRqUZEckj4bn5g9bgzWwGWAfc0uGxOTNbNLPFer0e8m3zSfjwSkQKItEhk8FG0ZjZs4AvAe9z91+2P+7uO4AdkE1VEOp9c0v48EpECiLRK9+DBLyZHUYW7le7+3UhtjkWmilSREJItKMYYhSNATuBPe7+0fxNGiOVZ0QkhMnJLNzn55Mq04TowZ8OvAO408zuaNz3QXe/PsC2RyvRva6IFFCCZZrcAe/u3wYsQFvGp7U0k8gvQkQKLsEOYzWvZFVpRkRCS3C4ZDXnoklwTysiJZHQ4I3q9eAT+uGLSAklVCGoXg8+wRMhIlIiCVUIqhfwCf3wRaSEElrtqTolGs05IyLjksjUBdUJ+ITqYiJSconkTXVKNCrNiMi4JJI35e/BqzQjIuOWyNQF5Q/4RA6VRKRiEsie8pdoEjlUEpGKSSB7zH38U7PXajVfXFwc+/uKiBSZme1291q/zy93iSaRoUoiUlGRM6jcAZ9ADUxEKixyBpW7Bp9ADUxEKixyBgXpwZvZ2WZ2r5ndZ2YXh9hmLhoaKSIpiDyFcIgl+1YBnwbeAJwMnGtmJ+fdbi4qzYhIKiLW4UOUaF4B3OfuPwIws2uAjcA9AbY9nEEOizR9sIiMUsQZbEME/PHAgy3f7wP+pP1JZjYHzAFMT08HeNsuBg3s1h9+88ozhb2IhBKxDj+2UTTuvsPda+5em5qaGt0bDVqemZ2FbduWw731tRpmKSIFFqIHvx84oeX7tY374hh0b9k6d3P7a9W7F5G8Cl6iuQ04ycxOJAv2c4C3BdjucPJMtt/+2tbAb/8lqXYvIv2IWKLJHfDu/pSZvRf4GrAKuMLd787dsmGEDl317kUkr4grPAW50MndrweuD7GtXEZ5KKTevYgMK1ImlOtK1nEeCql3LyL9ilSHL0fAt+4dYxwKqXcvIiuJVIcvR8BHPEvdkXr3ItIqUh2+HAGf8qRi6t2LiGrwOUQ8Sz0w9e5Fqkc1+CEVuder3r1INUSqMhR/wY8yzRzZOrVo6xQKoGkURIqs+Zmenx/rZ7b4PfiU6+95rNS7B5VzRIomQpmmuAEfe2jkuKmcI1JsETqjxQ341IZGjlu/J2sV+CJpiDAYpLgBX9bSzDBUzhFJX4SOVnEDvkhDI8dN5RyR9KgG3yeF0mBUzhGJTzX4PlW9/p7HIOUcBb5IoRUz4FV/D0eBLzIeKtH0ULWhkTHohK3IaBStRGNm24G/Ap4A/heYdfdfhGhYRyrNjN8gJ2xFJCl5e/A3Apc0lu37F+AS4O/zN6sLlWbiW+mErco3It0VrUTj7je0fPsd4C35mtODhkampf33oXq9SHdFK9G0eRfwhW4PmtkcMAcwPT0d8G0lGTpBK9JdhA5qz9kkzezrZnZXh38bW55zKfAUcHW37bj7DnevuXttamoqTOslLa2zYYJmxBRpFeHvvWcP3t3PWulxM3snsB440909ULukDDQEU2RZ0WrwZnY2sAX4c3f/dZgmSWkNGvgiZVLAGvyngGcCN5oZwHfc/d25WyXV0CvwQb16kRzyjqL5g1ANEel4EkplHCmLopVoREZOdXspiwKWaERGS3V7kaEp4KVYegW+evSSKpVoRAbU62paUOhLGlSiEcmp04dIZRyJLVInQwEv5dJpJI7KOBJbpE6G
Al7KT5OiSWyRZsJVwEv1aCSOVIQCXqpHV9DKuKlEIxKJrqCVUVOJRiQhKuNIKBE7Bwp4kU5UxpFQInYOFPAi/einjAMKfXm6iGtJK+BFhqWLqqSXyDt8BbzIsHRRlfQSeYffc03WfpjZRWbmZqa/YKm29nVp29ehBa1FWxVLS3DwIGzdGqU8AwEC3sxOAF4PPJC/OSIl077wOHQOfSmf+Xm47DJYsyba0VuIEs3HyNZl/UqAbYmUSz9lHFApp4winlxtytWDN7ONwH53/16g9oiUX3sZB57eq1cZp9gS2WH37MGb2deB53V46FLgg2TlmZ7MbA6YA5ienh6giSIVoAuryiWR35+5+3AvNPsj4Cbg14271gI/AV7h7g+v9NpareaLi4tDva9IJXTqASbSK5QVNH9HGzbAwkLw35WZ7Xb3Wr/PH7oG7+53Ase2vPFeoObuOqYUyUsXVhVTIj33Jo2DFykKXViVtgSGRbYLMg4ewN1n1HsXGaFOJ2fbh2Hq5Gw8CQyLbKcevEiRadHx+Frr7pBM7x0U8CLl0m8ZR6EfTsJlMgW8SJn0e2GVFjTJL+Gee5MCXqTs+gn9hHuhySrAz0wBL1JFWtBkeAXouTcp4EVE69L2o/n/P3gwGy0Dyfbcm4INkxSRkmkfgln1+XJad3jtM4QmSj14EemsVxmnKqNzOpVkCvJ/U8CLSH/6qdv3KusUZQfQ2s4CnEztRgEvIsMZZnROanX91vdvtrc91BOY131YCngRCadXL7/XDqBXj3/YHUI/QQ6dQ73TjqwgFPAiMjrt4ThoXX+QI4Dm6zvtDPoJ8tY2FTjUWyngRSSeQXv8K+0QoPvOoN8gL0Gotxp6wY88tOCHiATRbw8+5RO6Axh0wQ8FvIhIQQwa8LrQSUSkpHIHvJmdb2Y/MLO7zWxbiEaJiEh+uU6ymtlfABuBl7r742Z2bK/XiIjIeOTtwb8HuNzdHwdw9wP5myQiIiHkDfgXA39qZreY2X+Z2Wndnmhmc2a2aGaL9Xo959uKiEgvPUs0ZvZ14HkdHrq08frnAq8ETgO+aGYv8g5Dc9x9B7ADslE0eRotIiK99Qx4dz+r22Nm9h7gukag32pmvwUmAXXRRUQiyzUO3szeDbzA3T9kZi8GbgKmO/Xg215XB+4f+o3DmwSKMKl1UdoJxWlrUdoJausoFKWdkLV1jbtP9fuCvAF/OHAFcCrwBPABd//G0BuMxMwWB7l4IJaitBOK09aitBPU1lEoSjthuLbmGibp7k8Ab8+zDRERGQ1dySoiUlIK+MyO2A3oU1HaCcVpa1HaCWrrKBSlnTBEW6NMNiYiIqOnHryISEkp4EVESqrSAW9mZ5vZvWZ2n5ldHLs93ZjZCWb2TTO7pzFr54Wx27QSM1tlZt81s/+I3ZaVmNmzzezaxmyoe8zsVbHb1I2Zvb/xu7/LzHaZ2RGx2wRgZleY2QEzu6vlvuea2Y1m9sPG1+fEbGNTl7Zub/z+v29mXzazZ8dsY1OntrY8dpGZuZn1XMWksgFvZquATwNvAE4GzjWzk+O2qqungIvc/WSyaSH+NuG2AlwI7IndiD58Aviqu/8h8FISbbOZHQ9cANTc/RRgFXBO3Fb9zpXA2W33XQzc5O4nkV38mErn6Uqe3tYbgVPc/Y+B/wEuGXejuriSp7cVMzsBeD3wQD8bqWzAA68A7nP3HzXG819DNvVxctz9IXe/vXH7V2RBdHzcVnVmZmuBNwKfjd2WlZjZMcCfATshu6bD3X8Rt1UrWg1MmNlq4EjgJ5HbA4C7fwv4edvdG4GrGrevAt401kZ10amt7n6Duz/V+PY7wNqxN6yDLj9XgI8BW4C+RsdUOeCPBx5s+X4fiYZmKzObAdYBt8RtSVcfJ/sD/G3shvRwItmcSfONctJnzWxN7EZ14u77gQ+T9doeAh5x9xvitmpFx7n7Q43bDwPHxWzMAN4F/GfsRnRjZhuB/e7+vX5fU+WALxwzexbwJeB9
7v7L2O1pZ2brgQPuvjt2W/qwGngZ8Bl3XwccJJ1SwiEaNeyNZDulFwBrzKwQV5A35qVKfiy2mV1KVgq9OnZbOjGzI4EPAh8a5HVVDvj9wAkt369t3JckMzuMLNyvdvfrYreni9OBDWa2l6zk9Voz+1zcJnW1D9jn7s0joWvJAj9FZwE/dve6uz8JXAe8OnKbVvJTM3s+QONr0gsBmdk7gfXA3/SaKDGi3yfbwX+v8flaC9xuZp2mcv+dKgf8bcBJZnZiY9K0c4CFyG3qyMyMrFa8x90/Grs93bj7Je6+1t1nyH6e33D3JHua7v4w8KCZvaRx15nAPRGbtJIHgFea2ZGNv4UzSfSEcMMCcF7j9nnAVyK2ZUVmdjZZSXGDu/86dnu6cfc73f1Yd59pfL72AS9r/B13VdmAb5xYeS/wNbIPyxfd/e64rerqdOAdZD3iOxr//jJ2o0rgfOBqM/s+2Yyo/xy5PR01jjKuBW4H7iT73CZxib2Z7QL+G3iJme0zs03A5cDrzOyHZEcfl8dsY1OXtn4KOAq4sfG5+reojWzo0tbBt5PuEYmIiORR2R68iEjZKeBFREpKAS8iUlIKeBGRklLAi4iUlAJeRKSkFPAiIiX1/wn1MOUNX3cmAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rollout recorded at the first optimisation epoch (s0_) vs. the reference path.\n",
    "x_, y_ = s0_[:, 0].data.numpy(), s0_[:, 1].data.numpy()\n",
    "plt.scatter(x_, y_, c='b', s=20)\n",
    "plt.scatter(cx, cy, c='r', s=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(153) tensor(153.8121, grad_fn=<AddBackward0>)\n",
      "tensor(154) tensor(154.6242, grad_fn=<AddBackward0>)\n",
      "tensor(155) tensor(155.5035, grad_fn=<AddBackward0>)\n",
      "tensor(157) tensor(156.4247, grad_fn=<AddBackward0>)\n",
      "tensor(158) tensor(157.3033, grad_fn=<AddBackward0>)\n",
      "tensor(160) tensor(158.4999, grad_fn=<AddBackward0>)\n",
      "tensor(162) tensor(159.8758, grad_fn=<AddBackward0>)\n",
      "tensor(165) tensor(161.5962, grad_fn=<AddBackward0>)\n",
      "tensor(168) tensor(163.5596, grad_fn=<AddBackward0>)\n",
      "tensor(171) tensor(165.9267, grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# For each rolled-out state, print the nearest waypoint index next to the cumulative-speed index.\n",
    "# None is passed for the unused waypoint fields instead of IPython's `_` (the last cell output),\n",
    "# and the loop variable no longer overwrites the training rollout s__.\n",
    "for i, state_i in enumerate(s_):\n",
    "    print(search__(state_i, [waypoints[:, 0], waypoints[:, 1], None, None], _indx=start_),\n",
    "          vs[:i + 1].sum() + start_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search__(state,waypoints,_indx,window=None):\n",
    "    \"\"\"Index of the waypoint nearest to `state`, searching `window` points from `_indx`.\n",
    "\n",
    "    state: state[0] is x, state[1] is y.\n",
    "    waypoints: 4-sequence (cx, cy, _, _); only cx and cy are used.\n",
    "    _indx: first waypoint index of the search window.\n",
    "    window: window length; None keeps the original behaviour of using the global `L`.\n",
    "    Returns the absolute index (argmin offset + _indx) as a tensor.\n",
    "    \"\"\"\n",
    "    if window is None:\n",
    "        window = L   # NOTE(review): L is not defined in this section - presumably a star-import global; confirm\n",
    "    x, y = state[0], state[1]\n",
    "    cx, cy, _, _ = waypoints\n",
    "    dx = x - cx[_indx:_indx + window]\n",
    "    dy = y - cy[_indx:_indx + window]\n",
    "    d2 = dx.pow(2) + dy.pow(2)\n",
    "    return torch.argmin(d2) + _indx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class mpc:\n",
    "    \"\"\"Skeleton model-predictive controller (work in progress).\n",
    "\n",
    "    Currently only stores its collaborators.\n",
    "    TODO: add the control / step methods.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,model,evals,optimizer,len_horizon,waypoints):\n",
    "        self.model = model\n",
    "        self.evals = evals\n",
    "        self.optimizer = optimizer\n",
    "        self.len_horizon = len_horizon\n",
    "        self.waypoints = waypoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class grad_test(torch.autograd.Function):\n",
    "    \"\"\"Toy custom autograd Function computing x * x, used to sanity-check backward().\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx,x):\n",
    "        ctx.save_for_backward(x)\n",
    "        return x*x\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx,de):\n",
    "        x = ctx.saved_tensors[0]\n",
    "        print(x)   # shows the saved input when backward runs\n",
    "        # Chain rule: local derivative 2x times the upstream gradient de.\n",
    "        return 2*x*de"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check grad_test: d(x*x)/dx at x = 5 should be 10.\n",
    "x = torch.tensor([5.], requires_grad=True)\n",
    "g = grad_test.apply\n",
    "y = g(x)\n",
    "print(y)\n",
    "y.backward()\n",
    "x.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def error_deviation_parallel(states,waypoints,indices,ql,qc):\n",
    "    \"\"\"Weighted squared deviation of states from the waypoint rows selected by `indices`.\n",
    "\n",
    "    Columns 0-1 of each selected row are the reference position (cx, cy); columns 2-3\n",
    "    (gx, gy) are the direction used to project the position error (presumably the\n",
    "    path tangent - confirm). Returns ql * sum(el**2) + qc * sum(ec**2).\n",
    "    \"\"\"\n",
    "    dx = states[:, 0] - waypoints[indices, 0]\n",
    "    dy = states[:, 1] - waypoints[indices, 1]\n",
    "    gx = waypoints[indices, 2]\n",
    "    gy = waypoints[indices, 3]\n",
    "    el = -gx*dx - gy*dy   # component along (gx, gy), sign-flipped\n",
    "    ec = gy*dx - gx*dy    # component across (gx, gy)\n",
    "    return ql*el.pow(2).sum() + qc*ec.pow(2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def error_deviation_parallel_(states,refs,indices,ql,qc):\n",
    "    \"\"\"Weighted squared deviation of states from precomputed reference rows `refs`.\n",
    "\n",
    "    refs[:, 0:2] is the reference position and refs[:, 2:4] the projection direction.\n",
    "    `indices` is accepted for signature parity with error_deviation_parallel but unused.\n",
    "    Returns ql * sum(el**2) + qc * sum(ec**2).\n",
    "    \"\"\"\n",
    "    dx = states[:, 0] - refs[:, 0]\n",
    "    dy = states[:, 1] - refs[:, 1]\n",
    "    gx, gy = refs[:, 2], refs[:, 3]\n",
    "    el = -gx*dx - gy*dy\n",
    "    ec = gy*dx - gx*dy\n",
    "    return ql*el.pow(2).sum() + qc*ec.pow(2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "ql=1\n",
    "qc=1\n",
    "\n",
    "class deviation_error(torch.autograd.Function):\n",
    "    \"\"\"Path-deviation loss with a hand-written gradient for the speed profile `vs`.\n",
    "\n",
    "    forward(states, vs, waypoints): vs is rounded to waypoint indices, the error of\n",
    "    `states` against those waypoint rows is computed, and the local gradients w.r.t.\n",
    "    the states and the reference rows are cached for backward.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx,states,vs,waypoints):\n",
    "        states = states.data.clone()\n",
    "        states.requires_grad = True\n",
    "\n",
    "        inds = vs.round().long()\n",
    "\n",
    "        # Forward difference along the path at each selected waypoint.\n",
    "        # NOTE(review): inds+1 goes out of range if an index reaches the last waypoint.\n",
    "        dw = waypoints[inds+1]-waypoints[inds]\n",
    "\n",
    "        refs = waypoints[inds].data.clone().requires_grad_()\n",
    "        with torch.enable_grad():\n",
    "            error = error_deviation_parallel_(states,refs,inds,ql,qc)\n",
    "            de_s,de_w = torch.autograd.grad(error,[states]+[refs],grad_outputs=None,retain_graph=False,create_graph=False)\n",
    "\n",
    "        ctx.save_for_backward(states.data.clone(),vs.data.clone(),waypoints,de_s.data.clone(),de_w.data.clone(),dw.data.clone())\n",
    "\n",
    "        return error\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx,de):\n",
    "        states,vs,waypoints,de_s,de_w,dw = ctx.saved_tensors\n",
    "\n",
    "        # Row-wise dot product <d error/d ref_i, dw_i>; same values as diag(de_w @ dw.T)\n",
    "        # without building the n x n matrix.\n",
    "        de_theta = (de_w*dw).sum(dim=1).view(-1,1)\n",
    "\n",
    "        # mask[i, j] = 1 for i <= j, so de_v[i] = sum over j >= i of de_theta[j].\n",
    "        # NOTE(review): this treats vs as increments of a cumulative index, while forward()\n",
    "        # uses vs.round() directly as absolute indices - confirm which is intended.\n",
    "        _range = torch.arange(len(states)).view(-1,1)\n",
    "        range_ = _range.t()\n",
    "        mask = (_range<=range_).float()\n",
    "        de_v = (mask@de_theta).view(-1)\n",
    "\n",
    "        # Chain rule: scale the cached local gradients by the upstream gradient de.\n",
    "        return de_s*de,de_v*de,None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([7.1618, 2.9268, 3.1397, 0.1686, 0.0000, 0.0000]),\n",
       " tensor([ 6.3231,  2.9579,  2.6732,  0.8252, -0.5831,  0.1357],\n",
       "        grad_fn=<SelectBackward>))"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): state0 and state0_ are not defined in any cell from here on;\n",
    "# this display depends on an earlier cell having run first.\n",
    "state0,state0_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "ql=1\n",
    "qc=1\n",
    "\n",
    "# NOTE(review): this redefines deviation_error from an earlier cell; whichever cell ran last wins.\n",
    "class deviation_error(torch.autograd.Function):\n",
    "    # Scalar tracking error of `states` against the waypoints at indices round(vs).\n",
    "    # backward() approximates d(error)/d(vs) by projecting the negated x,y state gradient\n",
    "    # onto the step between consecutive waypoints.\n",
    "    \n",
    "    @staticmethod\n",
    "    def forward(ctx,states,vs,waypoints):\n",
    "        states = states.data.clone()\n",
    "        states.requires_grad = True\n",
    "        \n",
    "        inds = vs.round().long()\n",
    "        \n",
    "        # Step between consecutive waypoints; requires inds+1 < len(waypoints).\n",
    "        dw = waypoints[inds+1]-waypoints[inds]\n",
    "        \n",
    "        with torch.enable_grad():\n",
    "            error = error_deviation_parallel(states,waypoints,inds,ql,qc)\n",
    "            de_s = torch.autograd.grad(error,states,grad_outputs=None,retain_graph=False,create_graph=False)[0]\n",
    "        \n",
    "        ctx.save_for_backward(states.data.clone(),dw.data.clone(),de_s.data.clone())\n",
    "        return error\n",
    "    \n",
    "    @staticmethod\n",
    "    def backward(ctx,de):\n",
    "        states,dw,de_s = ctx.saved_tensors\n",
    "        \n",
    "        # Only the x,y columns contribute: the reference-point gradient is taken as -de_s[:, :2]\n",
    "        # and dotted row-wise with the waypoint step.\n",
    "        de_theta = (-de_s[:,:2]*dw[:,:2]).sum(dim=1).view(-1,1)\n",
    "        \n",
    "        # mask[i,j] = 1 for i <= j, so de_v[i] = sum of de_theta[j] over j >= i.\n",
    "        _range = torch.arange(len(states)).view(-1,1)\n",
    "        range_ = _range.t()\n",
    "        mask = (_range<=range_).float()\n",
    "        de_v = (mask@de_theta).view(-1)\n",
    "        \n",
    "        # Chain rule: scale by the upstream gradient `de` of the scalar error.\n",
    "        return de*de_s,de*de_v,None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "states=torch.ones(5,3).requires_grad_()\n",
    "vs = torch.Tensor([1.1,2.1,3.3,4.7,10.7]).requires_grad_()\n",
    "\n",
    "\n",
    "opt = torch.optim.Adam([states]+[vs],lr=0.1)\n",
    "\n",
    "# 101 synthetic waypoints 0.4, 1.4, ..., 100.4 in every column (torch.range is deprecated).\n",
    "waypoints = (torch.arange(101.)+0.4).view(-1,1).repeat(1,4)\n",
    "print( waypoints.size() )\n",
    "\n",
    "de = deviation_error.apply\n",
    "err = de(states,vs,waypoints)\n",
    "print(err)\n",
    "opt.zero_grad()\n",
    "err.backward()\n",
    "print(vs.grad.data,states.grad.data)\n",
    "\n",
    "# The len(states) x len(states) mask that deviation_error.backward builds\n",
    "# (torch.range(0,len(states)) produced one row and column too many).\n",
    "_range = torch.arange(len(states)).view(-1,1)\n",
    "range_ = _range.t()\n",
    "mask = (_range<=range_).float()\n",
    "mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Float column vector [0, 1, ..., 5]; torch.arange replaces the deprecated end-inclusive torch.range.\n",
    "torch.arange(6.).view(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "states = torch.nn.Parameter(torch.ones(10,3))\n",
    "waypoints = torch.zeros(100,4)\n",
    "v= torch.Tensor(np.arange(10)).float()\n",
    "inds= v.round().long()\n",
    "\n",
    "dw = waypoints[inds+1,:2]-waypoints[inds,:2]\n",
    "print(dw)\n",
    "# Use the indices computed above; `indices` is only bound by unrelated later cells.\n",
    "error = error_deviation_parallel(states,waypoints,inds,ql=1,qc=1)\n",
    "\n",
    "\n",
    "de_s = torch.autograd.grad(error,states,grad_outputs=None,retain_graph=False,create_graph=False)\n",
    "de_s[0].size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checks what gradient reaches `vs` through the rounded indices `inds`.\n",
    "# NOTE(review): `inds` and `vs` come from other cells. If `inds` was built with .long() it has\n",
    "# no grad_fn and loss.backward() raises; if built as vs.round() the gradient is presumably zero.\n",
    "\n",
    "loss = inds.sum()\n",
    "opt=torch.optim.SGD([vs],lr=0.1)\n",
    "opt.zero_grad()\n",
    "print(vs.grad)\n",
    "loss.backward()\n",
    "print(vs.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hl=20\n",
    "\n",
    "das = torch.nn.Parameter(torch.zeros(hl))\n",
    "dds = torch.nn.Parameter(torch.zeros(hl))\n",
    "\n",
    "# Fill in the initial controls before wrapping them: an in-place += on a Parameter\n",
    "# (a leaf that requires grad) raises a RuntimeError.\n",
    "controls_init = torch.zeros(hl,3)\n",
    "controls_init[:,2] += 0.01\n",
    "controls = torch.nn.Parameter(controls_init)\n",
    "\n",
    "vs = torch.nn.Parameter(torch.arange(1,hl).float())\n",
    "inds = vs.round()\n",
    "state = torch.Tensor([ 0,0,-0.48,0.01,0.0,0.01 ])\n",
    "\n",
    "# Roll the model forward hl-1 steps and stack the visited states as rows.\n",
    "# NOTE(review): this passes a 6-element state; the `forward` defined further down unpacks 8.\n",
    "_states = []\n",
    "for t in range(hl-1):\n",
    "    state = forward(state,controls[t])\n",
    "    _states.append(state.view(1,-1))\n",
    "\n",
    "states = torch.cat(_states,dim=0)\n",
    "\n",
    "inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def search_(state,waypoints,_indx,window=None):\n",
    "    # Index of the waypoint nearest to (state[0], state[1]), searched among rows\n",
    "    # _indx .. _indx+window-1 of `waypoints` (columns 0,1 are x,y).\n",
    "    # window defaults to the global L; any state with at least 2 entries is accepted.\n",
    "    if window is None:\n",
    "        window = L\n",
    "    x,y = state[0],state[1]\n",
    "    cx,cy = waypoints[:,0],waypoints[:,1]\n",
    "    dx = x-cx[_indx:_indx+window]\n",
    "    dy = y-cy[_indx:_indx+window]\n",
    "    d2 = dx.pow(2) + dy.pow(2)\n",
    "    indx = torch.argmin(d2) + _indx\n",
    "\n",
    "    return indx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect a single row (index 16) of the `states` tensor.\n",
    "states[16]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): search_parallel is defined in the next cell; run that cell first.\n",
    "search_parallel(states[:16],waypoints,0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "L=100\n",
    "def search_parallel(states,waypoints,_indx,window=None):\n",
    "    # For every row of `states`, the index of the nearest waypoint (by x,y = columns 0,1)\n",
    "    # among rows _indx .. _indx+window-1 of `waypoints`; window defaults to the global L.\n",
    "    # Returns a 1-D LongTensor with one index per state row.\n",
    "    if window is None:\n",
    "        window = L\n",
    "    xs,ys = states[:,0].view(-1,1),states[:,1].view(-1,1)\n",
    "    cx,cy = waypoints[:,0].view(1,-1),waypoints[:,1].view(1,-1)\n",
    "    # Broadcast to a (num_states, window) matrix of squared distances.\n",
    "    dx = xs - cx[:,_indx:_indx+window]\n",
    "    dy = ys - cy[:,_indx:_indx+window]\n",
    "    d2 = dx.pow(2) + dy.pow(2)\n",
    "    inds = torch.argmin(d2,dim=1) + _indx\n",
    "    return inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test of `forward` with a 6-element state and a 3-element control.\n",
    "# NOTE(review): the `forward` defined further down unpacks an 8-element state and a 2-element\n",
    "# control, so this only works with a 6-state model (presumably the one from kinetic_model).\n",
    "state=torch.Tensor([1,1,1,1,1,1])\n",
    "control=torch.Tensor([1,1,1])\n",
    "forward(state,control)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d=torch.ones(2)\n",
    "# Element-wise maximum of d and a tensor of 2s (expected: tensor([2., 2.])).\n",
    "torch.max(d,torch.full_like(d,2.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "THRESHOLD = -0.01\n",
    "learning_rate = 0.001\n",
    "EMAX=100000\n",
    "\n",
    "class constrained_optimization:\n",
    "    # Minimise obj(x) subject to con(x) <= 0 for every con in cons, using a log-barrier penalty\n",
    "    # weighted by myu and Adam; stops at the last iterate that satisfied all constraints.\n",
    "    def __init__(self,obj,cons,myu):\n",
    "        self.obj = obj\n",
    "        self.cons = cons\n",
    "        self.myu = myu\n",
    "    \n",
    "    def compute_penalty(self,x):\n",
    "        # Sum over constraints of log(-max(con(x), THRESHOLD)). Values below THRESHOLD are clipped,\n",
    "        # so the barrier only varies within |THRESHOLD| of the boundary; a violated constraint\n",
    "        # (con(x) > 0) gives the log of a negative number, i.e. nan.\n",
    "        penalty = 0\n",
    "        for con in self.cons:\n",
    "            value = con(x)\n",
    "            c = torch.max(value,THRESHOLD*torch.ones_like(value))\n",
    "            penalty += torch.log(-c).sum()\n",
    "        return penalty\n",
    "    \n",
    "    def compute_cost(self,x):\n",
    "        cost = self.obj(x) - self.myu*self.compute_penalty(x)\n",
    "        return cost\n",
    "    \n",
    "    def judge_bleaching(self,x):\n",
    "        # Truthy (a 0-dim tensor) iff every element of every constraint is <= 0.\n",
    "        judge = torch.prod(torch.cat([torch.prod(c(x)<=0).view(1) for c in self.cons]) )\n",
    "        return judge\n",
    "\n",
    "    def run(self,x0):\n",
    "        # Optimise from x0 for at most EMAX Adam steps; x0 itself is left unmodified (a bare\n",
    "        # Parameter(x0) would share its storage). Returns a detached tensor: the last feasible\n",
    "        # iterate, or the final iterate if feasibility never broke.\n",
    "        x = torch.nn.Parameter(x0.detach().clone())\n",
    "        opt=torch.optim.Adam([x],lr=learning_rate)\n",
    "        _x = x.data.clone()\n",
    "        for epoch in range(EMAX):\n",
    "            cost = self.compute_cost(x)\n",
    "            opt.zero_grad()\n",
    "            cost.backward()\n",
    "            opt.step()\n",
    "            \n",
    "            if self.judge_bleaching(x):\n",
    "                _x = x.data.clone()\n",
    "            else:\n",
    "                return _x\n",
    "        return x.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(x):\n",
    "    # Squared Euclidean norm of x.\n",
    "    return torch.sum(x*x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def const1(x):\n",
    "    # Element-wise constraint x <= 5, expressed as x - 5 <= 0.\n",
    "    return x - 5\n",
    "def const2(x):\n",
    "    # Unit ball centred at 1: ||x - 1||^2 - 1 <= 0.\n",
    "    shifted = x-1\n",
    "    return torch.sum(shifted*shifted)-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimise ||x||^2 subject to ||x - 1||^2 <= 1, starting from the feasible point (1.2, 1.2).\n",
    "co = constrained_optimization(objective,[const2],myu=0.0001)\n",
    "x0 = torch.FloatTensor([1.2,1.2])\n",
    "x=co.run(x0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# const2 at the optimiser result x, and the same expression with the centre moved to 1.01;\n",
    "# returned as a tuple so both values are displayed (a bare line before the last is discarded).\n",
    "(x-1).pow(2).sum()-1, (x-1.01).pow(2).sum()-1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value of const1 (x - 5) at the optimiser result x.\n",
    "x-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vehicle parameters; the dynamic `forward` model in a later cell reads some of these as globals.\n",
    "# NOTE(review): presumably mass (m), yaw inertia (I), CoG-to-front/rear axle distances (lf, lr)\n",
    "# and front/rear cornering stiffness (Kf, Kr); units are not stated.\n",
    "# lr here is a length, not a learning rate (the optimiser cell uses `learning_rate`).\n",
    "m=1500\n",
    "I=2500\n",
    "lf=1.1\n",
    "lr=1.6\n",
    "Kf=55\n",
    "Kr=60"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(state,control):\n",
    "    # One explicit-Euler step of a dynamic bicycle model with simplified Pacejka tyre forces\n",
    "    # of the form D*sin(C*atan(B*alpha)).\n",
    "    # state = (x, y, phi, vx, vy, r, delta, T); control = (d_delta, dt).\n",
    "    # Returns the next state as a tensor in the same order.\n",
    "    # NOTE(review): Dr, Cr, Df, Cf, B, Cm2, Cr0, Cr2, Ptv and lz are not defined in this notebook and\n",
    "    # must exist as globals before calling (lz is presumably the yaw inertia I). The global B is\n",
    "    # also rebound to a 3x2 matrix in a later cell.\n",
    "    # NOTE(review): this replaces any earlier `forward`; cells above call `forward` with 6-element states.\n",
    "    x,y,phi,vx,vy,r,delta,T = state\n",
    "    d_delta,dt = control\n",
    "    \n",
    "    cos_p = torch.cos(phi)\n",
    "    sin_p = torch.sin(phi)\n",
    "    cos_d = torch.cos(delta)\n",
    "    sin_d = torch.sin(delta)\n",
    "    \n",
    "    # Rear and front slip angles (sign conventions kept as written; not re-derived here).\n",
    "    alpha_r = torch.atan((vy-lr*r)/vx )\n",
    "    alpha_f = torch.atan((vy-lf*r)/vx - delta )\n",
    "    \n",
    "    # Lateral tyre forces. The front force must be Ffy: assigning it to Fry overwrote the rear\n",
    "    # force and left Ffy undefined.\n",
    "    Fry = Dr*torch.sin( Cr*torch.atan(B*alpha_r) )\n",
    "    Ffy = Df*torch.sin( Cf*torch.atan(B*alpha_f) )\n",
    "    Fx = Cm2*T - Cr0 - Cr2*vx*vx # Cm2, Cr0, Cr2 still to be determined\n",
    "    r_target = delta*vx/(lf+lr)\n",
    "    tau = (r_target-r)*Ptv\n",
    "    \n",
    "    # The unused linear-bicycle terms d_beta / d_r were dropped: they contained a syntax error\n",
    "    # and referenced the undefined names beta and V, and never fed the update below.\n",
    "    \n",
    "    dx = vx*cos_p - vy*sin_p\n",
    "    dy = vx*sin_p + vy*cos_p\n",
    "    d_phi = r\n",
    "    dvx = (Fx - Ffy*sin_d + m*vy*r)/m\n",
    "    dvy = (Fry - Ffy*cos_d - m*vx*r)/m\n",
    "    dr = (Ffy*lf*cos_d - Fry*lr + tau)/lz\n",
    "    \n",
    "    x_     = x + dx*dt\n",
    "    y_     = y + dy*dt\n",
    "    phi_   = phi + d_phi*dt\n",
    "    vx_    = vx + dvx*dt\n",
    "    vy_    = vy + dvy*dt\n",
    "    r_     = r + dr*dt\n",
    "    delta_ = delta + d_delta\n",
    "    T_     = T + dt\n",
    "    \n",
    "    return torch.stack([x_,y_,phi_,vx_,vy_,r_,delta_,T_])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: shape of a 1-D numpy array of length 10 -> (10,).\n",
    "a=np.zeros(10)\n",
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_sigma(df,dg,sigma_x,sigma_w,B):\n",
    "    # Covariance of df @ x + B @ e with e = dg @ x + n, where x has covariance sigma_x and\n",
    "    # n (covariance sigma_w) is independent of x (first-order / linearised propagation).\n",
    "    #   joint  = [[sigma_x, (dg sigma_x)^T], [dg sigma_x, dg sigma_x dg^T + sigma_w]]\n",
    "    #   result = [df, B] @ joint @ [df, B]^T\n",
    "    cross = dg@sigma_x\n",
    "    top = torch.cat([sigma_x,cross.t()],dim=1)\n",
    "    bottom = torch.cat([cross,cross@dg.t()+sigma_w],dim=1)\n",
    "    joint = torch.cat([top,bottom],dim=0)\n",
    "    jac = torch.cat([df,B],dim=1)\n",
    "    return jac@joint@jac.t()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(func,state,control,grad=True):\n",
    "    # Evaluate F = func(state, control) and, if grad, its Jacobian dF w.r.t. state.\n",
    "    # dF has one row per element of F (in iteration order) and one column per element of state;\n",
    "    # an output whose graph does not reach `state` gets an all-zero row instead of crashing.\n",
    "    # torch.autograd.grad raises if an element of F does not require grad.\n",
    "    F = func(state,control)\n",
    "    dF = None\n",
    "    if grad:\n",
    "        rows = []\n",
    "        for f in F:\n",
    "            # retain_graph=True: the same graph is reused for every output element.\n",
    "            (g,) = torch.autograd.grad(f,[state],grad_outputs=None,retain_graph=True,create_graph=False,allow_unused=True)\n",
    "            if g is None:\n",
    "                g = torch.zeros_like(state)\n",
    "            rows.append(g.view(1,-1))\n",
    "        dF = torch.cat(rows,dim=0)\n",
    "    return F,dF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process(state,control):\n",
    "    # Toy process model: (s0*c0, s1*c1, s2*c1) as a 1-D tensor of length 3.\n",
    "    gains = (control[0],control[1],control[1])\n",
    "    return torch.stack([state[k]*gains[k] for k in range(3)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def error_model(state,control):\n",
    "    # Toy noise model: 10% of (s0*c0, s1*c1), returned as a length-2 tensor.\n",
    "    return torch.stack([state[k]*control[k]*0.1 for k in range(2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small test covariances: 3-D state (sigma_x) and 2-D noise (sigma_w).\n",
    "# B = [I_2; 0] adds the 2 noise components to the first two state components.\n",
    "# NOTE(review): this rebinds the global B, which the dynamic `forward` also reads as a tyre factor.\n",
    "sigma_x = torch.eye(3)*2\n",
    "sigma_w = torch.eye(2)*2\n",
    "B = torch.cat([torch.eye(2),torch.zeros(1,2)],dim=0)\n",
    "B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linearise the toy process and noise models at state (1, 2, 3), control (2, 1),\n",
    "# then propagate the test covariances through them.\n",
    "state = torch.nn.Parameter(torch.Tensor([1.,2.,3.]))\n",
    "print(state.requires_grad)\n",
    "control = torch.Tensor([2.,1.])\n",
    "f,df = evaluate(process,state,control)\n",
    "g,dg = evaluate(error_model,state,control)\n",
    "construct_sigma(df,dg,sigma_x,sigma_w,B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cubic_spline_planner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense reference path (ds = 0.0001) through the same control points as the first plot.\n",
    "# NOTE(review): the print message appears to be left over from an LQR example; the next\n",
    "# cell repeats this computation.\n",
    "print(\"LQR steering control tracking start!!\")\n",
    "ax = [0.0, 6.0, 12.5, 10.0, 7.5, 3.0, -1.0]\n",
    "ay = [0.0, -3.0, -5.0, 6.5, 3.0, 5.0, -2.0]\n",
    "goal = [ax[-1], ay[-1]]\n",
    "\n",
    "cx, cy, cyaw, ck, s = cubic_spline_planner.calc_spline_course(\n",
    "        ax, ay, ds=0.0001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "print(\"LQR steering control tracking start!!\")\n",
    "ax = [0.0, 6.0, 12.5, 10.0, 7.5, 3.0, -1.0]\n",
    "ay = [0.0, -3.0, -5.0, 6.5, 3.0, 5.0, -2.0]\n",
    "goal = [ax[-1], ay[-1]]\n",
    "\n",
    "cx, cy, cyaw, ck, s = cubic_spline_planner.calc_spline_course(\n",
    "        ax, ay, ds=0.0001)\n",
    "i=0\n",
    "# Central first difference for the path tangent at point i, one-sided at the ends\n",
    "# (i-1 at i=0 would wrap around to the last point).\n",
    "i_prev, i_next = max(i-1,0), min(i+1,len(cx)-1)\n",
    "dx = cx[i_next]-cx[i_prev]\n",
    "dy = cy[i_next]-cy[i_prev]\n",
    "# Heading from the finite difference next to the planner's yaw (presumably radians).\n",
    "print(np.arctan2(dy,dx),cyaw[i])\n",
    "plt.scatter(cx[i:],cy[i:])\n",
    "plt.scatter(ax,ay,c='r');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slope gy/gx at the first waypoint; gx and gy are built in the next cell (run it first).\n",
    "gy[0]/gx[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import numpy as np\n",
    "\n",
    "cx = np.array(cx)\n",
    "cy = np.array(cy)\n",
    "\n",
    "dx = cx[1:]-cx[:-1]\n",
    "dy = cy[1:]-cy[:-1]\n",
    "\n",
    "# Unit tangent of each path segment. arctan2 gives the heading angle; the slope dy/dx is not\n",
    "# an angle and divides by zero on vertical segments.\n",
    "heading = np.arctan2(dy,dx)\n",
    "gx,gy = np.cos(heading),np.sin(heading)\n",
    "\n",
    "# The last waypoint has no outgoing segment: reuse the previous tangent so (gx, gy) stays a\n",
    "# unit vector (a zero vector would zero out both tracking-error terms there).\n",
    "gx=np.append(gx,gx[-1])\n",
    "gy=np.append(gy,gy[-1])\n",
    "\n",
    "# Columns: x, y, tangent x, tangent y.\n",
    "waypoints = torch.cat([torch.FloatTensor(cx).view(-1,1),torch.FloatTensor(cy).view(-1,1),torch.FloatTensor(gx).view(-1,1),torch.FloatTensor(gy).view(-1,1)] ,dim=1)\n",
    "waypoints.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Path heading at the current i from a central first difference (one-sided at the ends, so\n",
    "# i-1 never wraps to the last point), compared with the planner's yaw.\n",
    "i_prev, i_next = max(i-1,0), min(i+1,len(cx)-1)\n",
    "dx = cx[i_next]-cx[i_prev]\n",
    "dy = cy[i_next]-cy[i_prev]\n",
    "print(np.arctan2(dy,dx),cyaw[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Advanced indexing demo: column 0 of rows 1, 4 and 5.\n",
    "a=torch.zeros(10,4)\n",
    "indices=[1,4,5]\n",
    "a[indices,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weighted error (ql=1, qc=2) of three points at (1, 1) against waypoints 10, 20 and 30.\n",
    "# NOTE(review): compute_error_parallel is defined in a later cell; run that cell first.\n",
    "states = torch.ones(3,2)\n",
    "indices = [10,20,30]\n",
    "e = compute_error_parallel(states,waypoints,indices,1,2)\n",
    "e,waypoints[indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.round rounds to the nearest integer: 1.001 -> 1.\n",
    "a=torch.Tensor([1.001])\n",
    "a.round()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO(review): this cell held an unfinished bare `def` (a SyntaxError); complete or delete it.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_error_parallel(states,waypoints,indices,ql,qc):\n",
    "    # Weighted squared tracking error of each state position against waypoints[indices].\n",
    "    # states: columns 0,1 are x,y.  waypoints: columns (cx, cy, gx, gy).\n",
    "    # along = minus the component of (x-cx, y-cy) along (gx, gy); across = gy*dx - gx*dy.\n",
    "    # Returns ql*sum(along^2) + qc*sum(across^2).\n",
    "    cx,cy = waypoints[indices,0],waypoints[indices,1]\n",
    "    gx,gy = waypoints[indices,2],waypoints[indices,3]\n",
    "    dx = states[:,0]-cx\n",
    "    dy = states[:,1]-cy\n",
    "    along = -gx*dx - gy*dy\n",
    "    across = gy*dx - gx*dy\n",
    "    return ql*along.pow(2).sum() + qc*across.pow(2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_error(state,waypoints,indx):\n",
    "    # Tracking error of one state against waypoint row indx = (cx, cy, gx, gy).\n",
    "    # el: minus the component of (x-cx, y-cy) along (gx, gy); ec: the perpendicular component.\n",
    "    # Only state[0], state[1] (x, y) are read, so any state with at least 2 entries works.\n",
    "    x,y = state[0],state[1]\n",
    "    cx,cy,gx,gy = waypoints[indx]\n",
    "    el = -gx*(x-cx) - gy*(y-cy)\n",
    "    ec =  gy*(x-cx) - gx*(y-cy)\n",
    "    return el,ec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "L = 100\n",
    "\n",
    "def search_(state,waypoints,_indx,window=None):\n",
    "    # Index of the waypoint nearest to (x, y) among waypoints _indx .. _indx+window-1.\n",
    "    # waypoints is a 4-sequence (cx, cy, gx, gy) of per-waypoint tensors and state an 8-sequence\n",
    "    # (x, y, phi, vx, vy, r, delta, T); only cx, cy, x and y are used. window defaults to L.\n",
    "    # NOTE(review): this redefines search_ from an earlier cell, which indexes waypoints by column\n",
    "    # as a 2-D tensor instead.\n",
    "    if window is None:\n",
    "        window = L\n",
    "    x,y,phi,vx,vy,r,delta,T = state\n",
    "    cx,cy,gx,gy = waypoints\n",
    "    dx = x-cx[_indx:_indx+window]\n",
    "    dy = y-cy[_indx:_indx+window]\n",
    "    d2 = dx.pow(2) + dy.pow(2)\n",
    "    indx = torch.argmin(d2) + _indx\n",
    "\n",
    "    return indx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Tuple layouts read by the 8-element-state search_: waypoints = (cx, cy, gx, gy) and\n",
    "# state = (x, y, phi, vx, vy, r, delta, T). Only cx, cy, x and y are used, so the other slots\n",
    "# are None (a bare `_` here would pick up IPython's last output instead).\n",
    "waypoints = torch.Tensor(cx),torch.Tensor(cy),None,None\n",
    "states = torch.Tensor([0.00081]),torch.Tensor([-0.00041]),None,None,None,None,None,None\n",
    "ind = 7\n",
    "# Nearest waypoint index in the window starting at `ind`. compute_error cannot take this layout:\n",
    "# it reads waypoints[ind], which is out of range for a 4-tuple.\n",
    "ind_ = search_(states,waypoints,ind)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the result computed in the previous cell.\n",
    "ind_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x and y of waypoint 10, read from the (cx, cy, gx, gy) tuple built two cells above.\n",
    "waypoints[0][10],waypoints[1][10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
